#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Longformer on FULL Dataset for Fair Comparison with 3D RoBERTa
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import (
    LongformerTokenizer,
    LongformerForSequenceClassification,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
import json
import warnings
warnings.filterwarnings('ignore')

# Resolve the compute device once at import time: the default CUDA device if one
# is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    # Report the selected GPU's name and device 0's total memory (decimal GB).
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

class IMDBDataset(Dataset):
    """Map-style dataset over IMDB reviews, tokenized once up front.

    ``__init__`` runs ``tokenizer`` over every text (truncating and padding to
    ``max_length``) and caches the resulting ``input_ids`` / ``attention_mask``
    tensors, so ``__getitem__`` is only a list lookup plus a label tensor.

    Args:
        texts: iterable of reviews; each item is converted with ``str``.
        labels: indexable sequence of integer class labels aligned with ``texts``.
        tokenizer: Hugging Face-style tokenizer callable that accepts
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``.
        max_length: fixed token length every example is truncated/padded to.

    Raises:
        ValueError: if the number of texts differs from the number of labels.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Pre-tokenize for speed (like your 3D setup)
        print("Pre-tokenizing dataset for speed...")
        self.encodings = []

        for text in tqdm(texts, desc="Tokenizing"):
            encoding = tokenizer(
                str(text),
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt'
            )
            # return_tensors='pt' on a single string adds a leading batch dim of 1;
            # flatten to 1-D so DataLoader's default collate stacks to (batch, length).
            self.encodings.append({
                'input_ids': encoding['input_ids'].flatten(),
                'attention_mask': encoding['attention_mask'].flatten()
            })

        # __len__ is driven by labels while __getitem__ indexes both lists, so a
        # mismatch would silently drop examples or raise IndexError mid-epoch.
        if len(self.encodings) != len(labels):
            raise ValueError(
                f"Got {len(self.encodings)} texts but {len(labels)} labels"
            )

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the cached encoding for ``idx`` plus its label as a long tensor."""
        return {
            'input_ids': self.encodings[idx]['input_ids'],
            'attention_mask': self.encodings[idx]['attention_mask'],
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

def load_imdb_data_full():
    """Return the full IMDB training split and a fixed 5% stratified test subset.

    The whole ``train`` split is kept. The ``test`` split is cut down with
    ``train_test_split(train_size=0.05, stratify=labels, random_state=42)`` so
    the evaluation subset is the same one the 3D RoBERTa run used.

    Returns:
        Tuple ``(train_texts, train_labels, test_texts, test_labels)``.
    """
    print("Loading FULL IMDB dataset...")
    imdb = load_dataset('imdb')
    train_split = imdb['train']
    test_split = imdb['test']

    # Full training split (all 25,000 reviews), not a subset.
    train_texts = train_split['text']
    train_labels = train_split['label']
    pool_texts = test_split['text']
    pool_labels = test_split['label']

    # Seeded, stratified 5% slice of the test split; the remaining 95% is discarded.
    subset = train_test_split(
        pool_texts,
        pool_labels,
        train_size=0.05,
        stratify=pool_labels,
        random_state=42,
    )
    test_texts, test_labels = subset[0], subset[2]

    print(f"Training samples: {len(train_texts)} (FULL DATASET)")
    print(f"Test samples: {len(test_texts)} (matching 3D setup)")
    print(f"Training label distribution: {np.bincount(train_labels)}")
    print(f"Test label distribution: {np.bincount(test_labels)}")

    return train_texts, train_labels, test_texts, test_labels

def create_data_loaders(train_texts, train_labels, test_texts, test_labels,
                       tokenizer, max_length=512, batch_size=8):
    """Build pre-tokenized train/test datasets and wrap them in DataLoaders.

    Args:
        train_texts, train_labels: training reviews and their integer labels.
        test_texts, test_labels: evaluation reviews and their integer labels.
        tokenizer: tokenizer passed through to ``IMDBDataset``.
        max_length: token length every example is truncated/padded to.
        batch_size: examples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles each
        epoch; the test loader keeps the original order.
    """
    train_dataset = IMDBDataset(train_texts, train_labels, tokenizer, max_length)
    test_dataset = IMDBDataset(test_texts, test_labels, tokenizer, max_length)

    # Pinned host memory only speeds up host->GPU copies (and lets
    # non_blocking=True overlap them). Without CUDA it does nothing useful and
    # PyTorch warns, so enable it only when a GPU is present.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, test_loader

def train_model_with_timing(model, train_loader, optimizer, scheduler, device, epoch, max_grad_norm=1.0):
    """Run one training epoch and measure its wall-clock duration.

    Args:
        model: sequence-classification model called with ``input_ids``,
            ``attention_mask`` and ``labels``; its output must expose
            ``.loss`` and ``.logits``.
        train_loader: iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, after the optimizer.
        device: device the batch tensors are moved to. Mixed precision is used
            only when this is a CUDA device.
        epoch: zero-based epoch index, used only for the progress-bar label.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        Tuple ``(avg_loss, accuracy, epoch_time)``: mean per-batch loss,
        training accuracy over the epoch, and elapsed seconds. Loss and
        accuracy are ``0.0`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    predictions = []
    true_labels = []

    start_time = time.time()
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')

    # Mixed precision only when the batches really go to a CUDA device. Checking
    # mere CUDA availability would autocast/scale a CPU run on a GPU machine.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.autocast(device_type='cuda'):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss

            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        scheduler.step()

        # .item() forces a device sync; read the scalar once and reuse it.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        preds = torch.argmax(outputs.logits, dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

        progress_bar.set_postfix({
            'loss': loss_value,
            'lr': scheduler.get_last_lr()[0]
        })

    epoch_time = time.time() - start_time

    # Guard the empty-loader case instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = accuracy_score(true_labels, predictions) if true_labels else 0.0

    return avg_loss, accuracy, epoch_time

def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Args:
        model: sequence-classification model called with ``input_ids``,
            ``attention_mask`` and ``labels``; its output must expose
            ``.loss`` and ``.logits``.
        test_loader: iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: device the batch tensors are moved to. Autocast is used only
            when this is a CUDA device.

    Returns:
        Tuple ``(avg_loss, accuracy, predictions, true_labels)``. The last two
        are lists of per-example predicted and true labels, in loader order.
        Loss and accuracy are ``0.0`` if the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    predictions = []
    true_labels = []

    # Inference only needs autocast. A GradScaler exists to scale losses for
    # backward passes, so none is created here.
    use_amp = torch.device(device).type == 'cuda'

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            if use_amp:
                with torch.autocast(device_type='cuda'):
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )
            else:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

            total_loss += outputs.loss.item()
            num_batches += 1
            preds = torch.argmax(outputs.logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Guard the empty-loader case instead of dividing by zero.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = accuracy_score(true_labels, predictions) if true_labels else 0.0

    return avg_loss, accuracy, predictions, true_labels

def main():
    """Fine-tune Longformer on full IMDB and compare it with the 3D RoBERTa baseline.

    Loads the data and the ``allenai/longformer-base-4096`` checkpoint, then
    trains for ``EPOCHS`` epochs, evaluating after each one. Whenever test
    accuracy improves, the weights are saved to
    ``best_longformer_full_dataset.pt``. At the end it writes a summary to
    ``longformer_full_comparison.json`` and prints a comparison against the
    hard-coded 3D RoBERTa numbers.

    Returns:
        Tuple ``(results, comparison_data)``.
        - ``results`` holds the hyperparameters and best-epoch metrics. Once
          any epoch has improved, it also holds a ``'model'`` reference.
        - ``comparison_data`` is exactly what was written to the JSON file,
          without the model.
    """
    print("\n=== LONGFORMER FULL DATASET COMPARISON ===\n")

    # Hyperparameters (optimizer/schedule values match the 3D RoBERTa run).
    MAX_LENGTH = 2048      # Longer context than the 512-token 3D baseline (or try 1024)
    BATCH_SIZE = 8         # Reasonable for Longformer
    LEARNING_RATE = 5e-5   # Match your 3D setup
    EPOCHS = 2             # Match your 3D setup
    WARMUP_STEPS = 100     # Match your 3D setup
    WEIGHT_DECAY = 0.01    # Match your 3D setup
    MAX_GRAD_NORM = 1.0

    # Reported 3D RoBERTa baseline; accuracy is in PERCENT.
    BASELINE_ACCURACY_PCT = 91.68
    BASELINE_MIN_PER_EPOCH = 5
    BASELINE_MAX_LENGTH = 512

    print("LONGFORMER HYPERPARAMETERS FOR COMPARISON:")
    print(f"  Max Length: {MAX_LENGTH}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Learning Rate: {LEARNING_RATE}")
    print(f"  Epochs: {EPOCHS}")
    print("  Dataset: FULL IMDB (25K training samples)")
    # AMP in the train/eval loops is tied to the device, so report it the same way.
    print(f"  Mixed Precision: {'Yes' if device.type == 'cuda' else 'No'}")

    # Load FULL dataset
    train_texts, train_labels, test_texts, test_labels = load_imdb_data_full()

    # Initialize tokenizer and model
    print("Loading Longformer model and tokenizer...")
    tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    model = LongformerForSequenceClassification.from_pretrained(
        'allenai/longformer-base-4096',
        num_labels=2
    )

    model.to(device)
    print(f"Model loaded with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    train_loader, test_loader = create_data_loaders(
        train_texts, train_labels, test_texts, test_labels,
        tokenizer, MAX_LENGTH, BATCH_SIZE
    )

    optimizer = AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        eps=1e-8
    )

    # The scheduler is stepped once per batch, so count batches across all epochs.
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=total_steps
    )

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Total training steps: {total_steps}")

    best_accuracy = 0
    best_predictions = None   # predictions/labels of the best epoch, for the final report
    best_true_labels = None
    training_times = []

    results = {
        'max_length': MAX_LENGTH,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'weight_decay': WEIGHT_DECAY,
        'warmup_steps': WARMUP_STEPS,
        'max_grad_norm': MAX_GRAD_NORM,
        'dataset_size': 'full'
    }

    for epoch in range(EPOCHS):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        print('='*50)

        train_loss, train_accuracy, epoch_time = train_model_with_timing(
            model, train_loader, optimizer, scheduler, device, epoch, MAX_GRAD_NORM
        )
        training_times.append(epoch_time)

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Training time: {epoch_time:.2f}s ({epoch_time/60:.1f} min)")

        test_loss, test_accuracy, predictions, true_labels = evaluate_model(
            model, test_loader, device
        )

        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Model selection uses the same test subset that is reported.
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_predictions, best_true_labels = predictions, true_labels
            torch.save(model.state_dict(), 'best_longformer_full_dataset.pt')
            print(f"New best model saved! Accuracy: {best_accuracy:.4f}")

            results.update({
                'best_accuracy': best_accuracy,
                'final_train_loss': train_loss,
                'final_test_loss': test_loss,
                # Live reference returned to callers. It keeps training after
                # this epoch (the checkpoint holds the best weights) and it is
                # not JSON-serializable.
                'model': model
            })

    avg_training_time = float(np.mean(training_times))
    total_training_time = float(sum(training_times))
    results['training_time_per_epoch'] = avg_training_time

    # json.dump raises TypeError on the model object, which crashed the run
    # after training and left a truncated file. Write only serializable entries.
    json_results = {key: value for key, value in results.items() if key != 'model'}
    comparison_data = {
        'model_name': 'Longformer Full Dataset',
        'dataset': 'IMDB Binary Classification (Full)',
        'hyperparameters': json_results,
        'performance_metrics': {
            'best_accuracy': best_accuracy,
            'training_time_per_epoch': avg_training_time,
            'total_training_time': total_training_time
        }
    }

    with open('longformer_full_comparison.json', 'w') as f:
        json.dump(comparison_data, f, indent=2)

    print(f"\n{'='*60}")
    print("FINAL COMPARISON RESULTS")
    print('='*60)
    print(f"✅ Longformer Best Accuracy: {best_accuracy:.4f}")
    print(f"✅ Average Training Time: {avg_training_time:.2f}s ({avg_training_time/60:.1f} min/epoch)")
    print(f"✅ Total Training Time: {total_training_time/60:.1f} minutes")
    print(f"✅ Sequence Length: {MAX_LENGTH} tokens")
    print("✅ Dataset: Full IMDB (25K samples)")

    # accuracy_score returns a fraction in [0, 1]. Convert it to percent so it
    # is on the same scale as the baseline number.
    best_accuracy_pct = best_accuracy * 100

    print("\n🎯 COMPARISON WITH YOUR 3D ROBERTA:")
    print("="*60)
    print(f"3D RoBERTa:  {BASELINE_ACCURACY_PCT:.2f}% accuracy, ~{BASELINE_MIN_PER_EPOCH} min/epoch, {BASELINE_MAX_LENGTH} tokens")
    print(f"Longformer:  {best_accuracy_pct:.2f}% accuracy, ~{avg_training_time/60:.1f} min/epoch, {MAX_LENGTH} tokens")

    # Efficiency = accuracy% / (minutes per epoch * tokens). Both sides use
    # percent; previously Longformer's fraction was compared with the baseline's
    # percent, which skewed the ratio by 100x.
    efficiency_3d = BASELINE_ACCURACY_PCT / (BASELINE_MIN_PER_EPOCH * BASELINE_MAX_LENGTH)
    efficiency_longformer = best_accuracy_pct / (avg_training_time / 60 * MAX_LENGTH)

    print("\nEfficiency Score (accuracy/time/tokens):")
    print(f"3D RoBERTa:  {efficiency_3d:.6f}")
    print(f"Longformer:  {efficiency_longformer:.6f}")

    if efficiency_3d > efficiency_longformer:
        # Avoid ZeroDivisionError if Longformer never scored above 0.
        ratio = efficiency_3d / efficiency_longformer if efficiency_longformer > 0 else float('inf')
        print(f"🏆 3D RoBERTa is {ratio:.1f}x more efficient!")
    else:
        print(f"📊 Longformer is {efficiency_longformer/efficiency_3d:.1f}x more efficient")

    # Report the epoch whose accuracy is quoted above, not simply the last epoch.
    if best_true_labels is not None:
        print("\nClassification Report (best epoch):")
        print(classification_report(best_true_labels, best_predictions,
                                    target_names=['Negative', 'Positive']))

    return results, comparison_data

if __name__ == "__main__":
    _gpu_present = torch.cuda.is_available()
    if _gpu_present:
        # Return unoccupied blocks held by PyTorch's CUDA caching allocator
        # before the run starts.
        torch.cuda.empty_cache()

    results, comparison_data = main()
