#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Longformer for Yelp binary (polarity) classification.
Hyperparameters match the 3D RoBERTa experiment for a fair comparison.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
import json
import warnings
warnings.filterwarnings('ignore')

# Disable `datasets` caching so transforms are recomputed instead of reloaded from cache files
import datasets
datasets.disable_caching()

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from transformers import (
    LongformerTokenizer,
    LongformerModel,
    LongformerConfig,
    LongformerForSequenceClassification
)

from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# SPEED OPTIMIZATION 1: let cuDNN benchmark and choose the fastest kernels
# (results are allowed to be non-deterministic in exchange for speed)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# Single compute-device handle shared by the rest of the script
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print(f"Using device: {device}")

if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name()}")
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {_gpu_props.total_memory / 1e9:.1f} GB")

# Output layout: results/ with models/ and metrics/ subdirectories
for _results_dir in ("results", "results/models", "results/metrics"):
    os.makedirs(_results_dir, exist_ok=True)

# Fixed seeds for reproducible initialisation and shuffling
np.random.seed(42)
torch.manual_seed(42)


class OptimizedYelpDataset(Dataset):
    """SPEED OPTIMIZATION 2: Yelp dataset tokenized once up front, padded lazily.

    Every text is tokenized a single time (in batches, truncated to
    ``max_length``). Its token ids are stored *unpadded* in one flat int32
    tensor plus an offsets tensor. ``__getitem__`` pads back to ``max_length``,
    so each item has the same ``input_ids`` / ``attention_mask`` that
    ``tokenizer(text, truncation=True, padding='max_length', max_length=...)``
    would produce.

    Storage scales with the real token count instead of
    ``2 * len(texts) * max_length`` int64 values (about 18 GB for 560k texts at
    ``max_length=2048``). Keeping everything in two large tensors, rather than
    a Python list of per-item tensors, also avoids most of the copy-on-read
    memory growth that per-object refcounting causes in forked DataLoader
    workers.

    Args:
        texts: Iterable of texts; each is converted with ``str()``.
        labels: Indexable sequence of integer labels aligned with ``texts``.
        tokenizer: Hugging Face tokenizer. ``pad_token_id`` is used for padding
            and ``padding_side`` (default ``'right'``) decides where padding goes.
        max_length: Length every returned example is truncated/padded to.
        cache_dir: Unused; accepted only for backward compatibility.
        tokenize_batch_size: Number of texts passed to the tokenizer per call.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512, cache_dir=None,
                 tokenize_batch_size=1000):
        self.labels = labels
        self.max_length = max_length
        self.pad_token_id = tokenizer.pad_token_id
        # Mirror the tokenizer's own padding placement.
        self.padding_side = getattr(tokenizer, 'padding_side', 'right')
        tokenize_batch_size = max(1, int(tokenize_batch_size))

        print("Pre-tokenizing Yelp dataset for Longformer...")
        texts = [str(text) for text in texts]
        chunks = []   # one flat int32 tensor of token ids per tokenizer batch
        lengths = []  # unpadded token count of each example, in order

        with tqdm(total=len(texts), desc="Tokenizing") as progress:
            for start in range(0, len(texts), tokenize_batch_size):
                batch = texts[start:start + tokenize_batch_size]
                batch_ids = tokenizer(
                    batch,
                    truncation=True,
                    max_length=max_length,
                )['input_ids']
                lengths.extend(len(ids) for ids in batch_ids)
                chunks.append(torch.tensor(
                    [token for ids in batch_ids for token in ids],
                    dtype=torch.int32,
                ))
                progress.update(len(batch))

        self._flat_ids = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.int32)
        # Example i lives at _flat_ids[_offsets[i]:_offsets[i + 1]].
        self._offsets = torch.zeros(len(lengths) + 1, dtype=torch.long)
        self._offsets[1:] = torch.cumsum(torch.tensor(lengths, dtype=torch.long), dim=0)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one padded example as long tensors plus its label."""
        if idx < 0:  # keep list-style negative indexing working
            idx += len(self)
        start = int(self._offsets[idx])
        end = int(self._offsets[idx + 1])
        length = end - start

        input_ids = torch.full((self.max_length,), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros(self.max_length, dtype=torch.long)
        if self.padding_side == 'left':
            span = slice(self.max_length - length, self.max_length)
        else:
            span = slice(0, length)
        input_ids[span] = self._flat_ids[start:end]  # int32 -> int64 on copy
        attention_mask[span] = 1

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }


class OptimizedLongformerForSequenceClassification(nn.Module):
    """
    Longformer encoder with a dropout + linear classification head on token 0 ([CLS]).

    Local attention windows:
    - The first ``num_full_window_layers`` layers use ``attention_window``.
    - The remaining layers use a reduced window, ``min(64, attention_window // 8)``
      rounded down to an even value of at least 2, because Longformer requires
      positive even windows.

    This layout is intended to mirror the 3D RoBERTa model's partial conversion.

    The windows are written into a private copy of ``config`` *before* the
    encoder is constructed. Longformer's self-attention layers read
    ``config.attention_window`` when they are built, so changing the config
    after construction would not change the layers. The caller's ``config`` is
    left untouched.

    Note: the encoder is built from ``config`` alone, so its weights are randomly
    initialised. Load pretrained weights into ``self.longformer`` separately if
    needed.

    Args:
        config: LongformerConfig; ``num_labels`` sets the classifier output size.
        attention_window: Local window for the leading layers (must be even).
        num_full_window_layers: Number of leading layers keeping the full window.
    """
    def __init__(self, config, attention_window=512, num_full_window_layers=6):
        super().__init__()
        import copy

        # Work on a copy so the caller's config object is not mutated.
        config = copy.deepcopy(config)
        self.num_labels = config.num_labels
        self.config = config

        # OPTIMIZATION: full window on the leading layers, reduced window on the
        # rest, to match the 3D model's partial conversion.
        reduced_window = max(2, (min(64, attention_window // 8) // 2) * 2)
        num_full = max(0, num_full_window_layers)
        config.attention_window = [
            attention_window if layer < num_full else reduced_window
            for layer in range(config.num_hidden_layers)
        ]

        # Build the encoder only now so every layer picks up its configured window.
        self.longformer = LongformerModel(config)

        # Classification head
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        print(f"Longformer attention windows: {self.longformer.config.attention_window}")

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Classify a batch from the [CLS] representation.

        Unless the caller passes ``global_attention_mask`` in kwargs, global
        attention is placed on position 0 of ``input_ids``. The classifier reads
        that token.

        Returns:
            dict with ``loss`` (cross-entropy, or None when ``labels`` is None),
            ``logits`` of shape (batch, num_labels), ``hidden_states`` and
            ``attentions`` (as returned by the encoder, possibly None).
        """
        global_attention_mask = kwargs.pop('global_attention_mask', None)
        if global_attention_mask is None and input_ids is not None:
            global_attention_mask = torch.zeros_like(input_ids, dtype=torch.long)
            global_attention_mask[:, 0] = 1  # Global attention on [CLS] token

        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            return_dict=True,
            **kwargs
        )

        # [CLS] token representation: (batch_size, hidden_size)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': outputs.hidden_states,
            'attentions': outputs.attentions
        }


def load_yelp_data_fast(test_fraction=0.05, random_state=42):
    """OPTIMIZATION: Load Yelp reviews as a binary sentiment task (same as 3D model).

    Tries the ``yelp_polarity`` dataset first. If that fails, it falls back to
    ``yelp_review_full`` and maps the 5 labels to binary (0,1,2 -> 0; 3,4 -> 1).
    The test split is then stratified-subsampled to ``test_fraction`` of its
    size for faster evaluation.

    Args:
        test_fraction: Fraction of the test split to keep. ``None`` or a value
            >= 1.0 keeps the whole split.
        random_state: Seed for the stratified test subsample.

    Returns:
        (train_texts, train_labels, test_texts, test_labels) as Python lists.

    Raises:
        The exception raised while loading ``yelp_review_full`` if both loads fail.
    """
    print("Loading Yelp Reviews dataset...")

    try:
        # Try to load the Yelp polarity dataset (binary classification)
        dataset = load_dataset('yelp_polarity')

        # Materialise as plain lists; newer `datasets` releases may return
        # lazy column objects instead of lists.
        train_texts = list(dataset['train']['text'])
        train_labels = list(dataset['train']['label'])
        test_texts = list(dataset['test']['text'])
        test_labels = list(dataset['test']['label'])

        print("Loaded Yelp Polarity dataset")
        print("Labels: 0 (negative), 1 (positive)")

    except Exception as e:
        print(f"Failed to load yelp_polarity: {e}")
        print("Trying yelp_review_full dataset and converting to binary...")

        try:
            # Load full Yelp dataset and convert to binary
            dataset = load_dataset('yelp_review_full')

            train_texts = list(dataset['train']['text'])
            test_texts = list(dataset['test']['text'])

            # Convert 5-class to binary: 0,1,2 -> 0 (negative), 3,4 -> 1 (positive)
            train_labels = [0 if label < 3 else 1 for label in dataset['train']['label']]
            test_labels = [0 if label < 3 else 1 for label in dataset['test']['label']]

            print("Loaded Yelp Review Full dataset and converted to binary")
            print("Original labels 0,1,2 -> 0 (negative), labels 3,4 -> 1 (positive)")

        except Exception as e2:
            print(f"Failed to load yelp datasets: {e2}")
            print("Please install the datasets library and check your internet connection")
            raise

    # Use smaller test set for faster evaluation (same as 3D model)
    if test_fraction is not None and test_fraction < 1.0:
        test_texts, _, test_labels, _ = train_test_split(
            test_texts, test_labels, train_size=test_fraction,
            stratify=test_labels, random_state=random_state
        )

    print(f"Training samples: {len(train_texts)}")
    print(f"Test samples: {len(test_texts)}")
    print(f"Label distribution - Train: {np.bincount(train_labels)}")
    print(f"Label distribution - Test: {np.bincount(test_labels)}")

    return train_texts, train_labels, test_texts, test_labels


def create_fast_data_loaders(train_texts, train_labels, test_texts, test_labels,
                            tokenizer, max_length=256, batch_size=16, num_workers=2):
    """OPTIMIZATION: Build the train/test DataLoaders (same setup as 3D model).

    Both splits are wrapped in OptimizedYelpDataset (tokenized once up front).
    The train loader shuffles; the test loader keeps order.

    Pinned host memory is requested only when CUDA is available. Pinning only
    speeds up host-to-GPU copies, and recent PyTorch versions warn when it is
    requested with no accelerator present.

    Args:
        num_workers: Worker processes per loader (default 2, as before).

    Returns:
        (train_loader, test_loader)
    """
    train_dataset = OptimizedYelpDataset(train_texts, train_labels, tokenizer, max_length)
    test_dataset = OptimizedYelpDataset(test_texts, test_labels, tokenizer, max_length)

    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, test_loader


def train_model_fast(model, train_loader, optimizer, scheduler, device, epoch,
                    max_grad_norm=1.0, accumulation_steps=2, scaler=None):
    """OPTIMIZATION: One training epoch with gradient accumulation and mixed precision.

    Gradient accumulation:
    - Gradients from ``accumulation_steps`` consecutive batches are summed
      before each clipped optimizer step and scheduler step.
    - Each batch's loss is divided by ``accumulation_steps``, so the summed
      gradient is an average.
    - If the batch count is not a multiple of ``accumulation_steps``, the
      trailing partial group is still applied at the end of the epoch. It
      therefore cannot leak into the next epoch's first step. Its gradient is
      smaller because it averages over fewer batches.

    Mixed precision (autocast + GradScaler) is used when ``device`` is a CUDA device.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., labels=...)``;
            must return a dict with ``'loss'`` and ``'logits'``.
        train_loader: Sized iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        optimizer: Stepped once per accumulation group.
        scheduler: Stepped once per accumulation group.
        device: Target device (``torch.device`` or string).
        epoch: Zero-based epoch index, used only for the progress-bar label.
        max_grad_norm: Gradient-norm clipping threshold.
        accumulation_steps: Batches per optimizer step (values < 1 become 1).
        scaler: Optional GradScaler to reuse across epochs so its loss scale
            persists. A fresh one is created when omitted and AMP is active.

    Returns:
        (average per-batch loss, training accuracy) for the epoch. Both are NaN
        if the loader yields no batches.
    """
    import contextlib

    accumulation_steps = max(1, int(accumulation_steps))
    use_amp = torch.device(device).type == 'cuda'
    if not use_amp:
        scaler = None
    elif scaler is None:
        scaler = torch.cuda.amp.GradScaler()

    model.train()
    optimizer.zero_grad()  # start the epoch from clean gradients

    num_batches = len(train_loader)
    total_loss = 0.0
    predictions = []
    true_labels = []

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        amp_context = torch.cuda.amp.autocast() if use_amp else contextlib.nullcontext()
        with amp_context:
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Scale the loss so the accumulated gradient is a mean over the group.
            loss = outputs['loss'] / accumulation_steps

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        group_complete = (i + 1) % accumulation_steps == 0
        if group_complete or (i + 1) == num_batches:
            if scaler is not None:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

        batch_loss = loss.item() * accumulation_steps  # undo the accumulation scaling
        total_loss += batch_loss
        preds = torch.argmax(outputs['logits'], dim=1)
        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

        progress_bar.set_postfix({
            'loss': batch_loss,
            'lr': scheduler.get_last_lr()[0]
        })

    if not true_labels:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(true_labels, predictions)
    return avg_loss, accuracy


def _evaluate(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Uses CUDA autocast when ``device`` is a CUDA device, mirroring training.

    Returns:
        (mean per-batch loss, accuracy). Both are NaN if the loader is empty.
    """
    model.eval()
    use_amp = torch.device(device).type == 'cuda'
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            else:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            test_loss += outputs['loss'].item()
            preds = torch.argmax(outputs['logits'], dim=1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)
            num_batches += 1

    if test_total == 0:
        return float('nan'), float('nan')
    return test_loss / num_batches, test_correct / test_total


def main_longformer():
    """OPTIMIZATION: Fine-tune pretrained Longformer on Yelp (matching the 3D RoBERTa setup).

    Steps:
    1. Load the Yelp data, plus the ``allenai/longformer-base-4096`` tokenizer
       and config.
    2. Build OptimizedLongformerForSequenceClassification and copy the
       pretrained encoder weights into it. The wrapper alone builds a randomly
       initialised encoder.
    3. Train for EPOCHS epochs, evaluating after each one.

    Whenever test accuracy improves, the weights are saved to
    ``results/models/longformer_yelp_best.pt``. Per-epoch metrics are written
    to ``results/metrics/longformer_yelp_metrics.json``.

    Returns:
        (model, best_accuracy): the model after the final epoch and the best
        test accuracy observed.
    """
    print("\n=== OPTIMIZED Longformer for Yelp Reviews (Matching 3D RoBERTa) ===\n")

    MODEL_NAME = 'allenai/longformer-base-4096'

    # OPTIMIZATION: Same hyperparameters as 3D model for fair comparison
    MAX_LENGTH = 2048       # Same as 3D model
    BATCH_SIZE = 16        # Same as 3D model
    LEARNING_RATE = 5e-5   # Same as 3D model
    EPOCHS = 2            # Same as 3D model
    WARMUP_STEPS = 100     # Same as 3D model
    WEIGHT_DECAY = 0.01    # Same as 3D model
    MAX_GRAD_NORM = 1.0    # Same as 3D model
    ATTENTION_WINDOW = 512 # Longformer local attention window
    ACCUMULATION_STEPS = 4 # Same as 3D model

    print("LONGFORMER OPTIMIZATIONS (matching 3D RoBERTa):")
    print(f"  Max Length: {MAX_LENGTH}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Attention Window: {ATTENTION_WINDOW}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Mixed Precision: {'Yes' if torch.cuda.is_available() else 'No'}")
    print(f"  Gradient Accumulation: {ACCUMULATION_STEPS} steps")

    # Load Yelp data (same function as 3D model)
    train_texts, train_labels, test_texts, test_labels = load_yelp_data_fast()

    print("Loading Longformer tokenizer...")
    tokenizer = LongformerTokenizer.from_pretrained(MODEL_NAME)

    print("Loading base Longformer model...")
    config = LongformerConfig.from_pretrained(MODEL_NAME)
    config.num_labels = 2  # Binary classification for Yelp sentiment

    print("Creating optimized Longformer model for Yelp...")
    model = OptimizedLongformerForSequenceClassification(
        config,
        attention_window=ATTENTION_WINDOW
    )

    # The wrapper builds its encoder from the config alone (random weights);
    # copy in the pretrained encoder so this is fine-tuning, not training from scratch.
    # Attention window sizes do not affect parameter shapes, so the keys line up.
    print("Loading pretrained Longformer weights...")
    pretrained = LongformerModel.from_pretrained(MODEL_NAME)
    model.longformer.load_state_dict(pretrained.state_dict())
    del pretrained

    model.to(device)
    print(f"Model loaded with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    # Create data loaders (same as 3D model)
    train_loader, test_loader = create_fast_data_loaders(
        train_texts, train_labels, test_texts, test_labels,
        tokenizer, MAX_LENGTH, BATCH_SIZE
    )

    # Setup optimizer (same as 3D model)
    optimizer = AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        eps=1e-8
    )

    # Linear warmup over WARMUP_STEPS, then constant LR (same as 3D model).
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        total_iters=WARMUP_STEPS
    )

    total_steps = len(train_loader) * EPOCHS // ACCUMULATION_STEPS
    print(f"Starting Longformer training for {EPOCHS} epochs on Yelp dataset...")
    print(f"Effective batch size: {BATCH_SIZE * ACCUMULATION_STEPS}")
    print(f"Approximate optimizer steps: {total_steps}")

    models_dir = os.path.join("results", "models")
    metrics_dir = os.path.join("results", "metrics")
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)
    best_model_path = os.path.join(models_dir, "longformer_yelp_best.pt")

    # Training loop (same as 3D model)
    best_accuracy = 0
    total_train_time = 0.0
    history = []

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        start_time = time.time()
        train_loss, train_accuracy = train_model_fast(
            model, train_loader, optimizer, scheduler, device, epoch,
            MAX_GRAD_NORM, ACCUMULATION_STEPS
        )
        train_time = time.time() - start_time
        total_train_time += train_time

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Training time: {train_time:.2f}s")

        # Quick evaluation
        test_loss, test_accuracy = _evaluate(model, test_loader, device)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'train_time_sec': train_time,
        })

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print(f"New best accuracy: {best_accuracy:.4f}")
            torch.save(model.state_dict(), best_model_path)

    avg_epoch_time = total_train_time / EPOCHS if EPOCHS else 0.0
    with open(os.path.join(metrics_dir, "longformer_yelp_metrics.json"), "w", encoding="utf-8") as f:
        json.dump({
            'best_accuracy': best_accuracy,
            'total_train_time_sec': total_train_time,
            'epochs': history,
        }, f, indent=2)

    print(f"\n🚀 LONGFORMER TRAINING ON YELP COMPLETED!")
    print(f"✅ Best Accuracy: {best_accuracy:.4f}")
    print(f"✅ Training Speed: ~{avg_epoch_time/60:.1f} min/epoch")
    print(f"✅ Total Time: ~{total_train_time/60:.1f} minutes")

    # Performance comparison
    print(f"\n📊 COMPARISON WITH 3D RoBERTa:")
    print(f"3D RoBERTa Accuracy: 94.0%")
    print(f"Longformer Accuracy: {best_accuracy*100:.1f}%")
    if best_accuracy > 0.94:
        print("🏆 Longformer outperformed 3D RoBERTa!")
    elif best_accuracy > 0.92:
        print("⚖ Longformer performance is competitive with 3D RoBERTa")
    else:
        print("📈 3D RoBERTa achieved higher accuracy")

    return model, best_accuracy


if __name__ == "__main__":
    # OPTIMIZATION: release cached CUDA allocations before the run starts
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()

    model, accuracy = main_longformer()
