"""
GoEmotions Training and Evaluation Script
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import json
import logging
from tqdm import tqdm
from sklearn.metrics import (
    f1_score, accuracy_score, precision_score, recall_score,
    hamming_loss, jaccard_score, classification_report
)
from scipy import stats
from datasets import load_dataset
from transformers import (
    AutoModel, AutoTokenizer, AutoConfig,
    BertModel, BertTokenizer,
    RobertaModel, RobertaTokenizer,
    DebertaV2Model, DebertaV2Tokenizer,
    ElectraModel, ElectraTokenizer,
    DistilBertModel, DistilBertTokenizer,
    XLNetModel, XLNetTokenizer,
    LongformerModel, LongformerTokenizer, LongformerConfig, # Added LongformerConfig
    BigBirdModel, BigBirdTokenizer,
    ReformerModel, ReformerTokenizer, get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
import warnings
# Suppress every warning category for the whole process; this also hides
# deprecation notices raised by the imported libraries.
warnings.filterwarnings('ignore')

# Setup logging: INFO level, timestamped records, one module-level logger
# shared by all functions in this file.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ==================== Configuration ====================

@dataclass
class Config:
    """Training configuration.

    Creating an instance has side effects (see ``__post_init__``): the
    checkpoint directory is created on disk and the torch / NumPy random
    number generators are seeded with ``seed``.
    """
    # Model
    max_length: int = 128  # Sequence length passed to the tokenizer
    num_labels: int = 27  # Width of the classifier output layer
    dropout: float = 0.1
    threshold: float = 0.3  # Multi-label threshold

    # Training
    batch_size: int = 32
    learning_rate: float = 3e-5
    num_epochs: int = 6
    warmup_ratio: float = 0.1  # Fraction of scheduler steps used for warmup
    weight_decay: float = 0.01
    gradient_accumulation_steps: int = 1
    max_grad_norm: float = 1.0  # Gradient-clipping norm

    # System
    # NOTE: this default is computed once, when the class body is executed,
    # not every time a Config is instantiated.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed: int = 42
    checkpoint_dir: Path = Path('checkpoints_goemotions')

    def __post_init__(self):
        """Create the checkpoint directory and seed the RNGs."""
        # Accept plain strings as well as Path objects so callers can pass
        # a path read from the command line or a config file.
        self.checkpoint_dir = Path(self.checkpoint_dir)
        # parents=True: a nested location such as ``runs/exp1/checkpoints``
        # must not fail just because the intermediate folders are missing.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Set seeds
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.seed)

# ==================== Dataset ====================

class GoEmotionsDataset(Dataset):
    """Torch dataset that tokenizes one text on every item lookup.

    ``texts`` and ``labels`` are parallel indexable sequences; ``tokenizer``
    is called with the text plus truncation / padding keyword arguments and
    must return a mapping with ``input_ids`` and ``attention_mask`` entries.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # The number of samples is driven by the text sequence.
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized text and its label vector for index ``idx``."""
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        # Every sample is padded / truncated to exactly ``max_length`` so
        # that the default collate function can stack them.
        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # Drop the leading dimension added by ``return_tensors='pt'``.
        sample = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        # Float dtype: the training loop feeds labels to a BCE-with-logits loss.
        sample['labels'] = torch.tensor(target, dtype=torch.float)
        return sample

def load_goemotions_data():
    """Download GoEmotions and return its three splits with multi-hot labels.

    Returns:
        ``(train, val, test, emotion_labels)`` where each split is a
        ``(texts, label_vectors)`` tuple and every label vector has one 0/1
        entry per name in ``emotion_labels``.
    """
    logger.info("Loading GoEmotions dataset from HuggingFace...")

    corpus = load_dataset('go_emotions', 'simplified')

    # Index in this list == label id used by the dataset.
    emotion_labels = [
        'admiration', 'amusement', 'anger', 'annoyance', 'approval',
        'caring', 'confusion', 'curiosity', 'desire', 'disappointment',
        'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear',
        'gratitude', 'grief', 'joy', 'love', 'nervousness',
        'optimism', 'pride', 'realization', 'relief', 'remorse',
        'sadness', 'surprise'
    ]
    num_emotions = len(emotion_labels)

    def to_multi_hot(row):
        """Turn a row's list of label ids into a fixed-length 0/1 vector."""
        vector = [0] * num_emotions
        for label_id in row['labels']:
            # Ids beyond the list above are silently dropped, so a row that
            # only carries such ids becomes an all-zero vector.
            # NOTE(review): presumably this is the 'neutral' class — confirm
            # against the dataset card.
            if label_id >= num_emotions:
                continue
            vector[label_id] = 1
        return vector

    # Same processing for every split: raw text column + encoded labels.
    prepared = {}
    for split_name in ('train', 'validation', 'test'):
        split = corpus[split_name]
        prepared[split_name] = (split['text'], [to_multi_hot(row) for row in split])

    train_split = prepared['train']
    val_split = prepared['validation']
    test_split = prepared['test']

    logger.info(f"Dataset loaded - Train: {len(train_split[0])}, Val: {len(val_split[0])}, Test: {len(test_split[0])}")

    return train_split, val_split, test_split, emotion_labels

# ==================== Model Definitions ====================

class MultiLabelClassifier(nn.Module):
    """Encoder plus a two-layer head that emits one logit per label.

    Args:
        encoder: Callable backbone. It is invoked as
            ``encoder(input_ids, attention_mask=..., **kwargs)`` and must
            return an object exposing ``last_hidden_state`` and, optionally,
            ``pooler_output``.
        hidden_size: Width of the encoder's output vectors.
        num_labels: Number of output logits.
        dropout: Dropout probability used both on the pooled vector and
            inside the classification head.
    """

    def __init__(self, encoder, hidden_size, num_labels, dropout=0.1):
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout)

        # Multi-label classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        """Return raw (pre-sigmoid) logits, one row per input sequence."""
        # Get encoder outputs
        outputs = self.encoder(input_ids, attention_mask=attention_mask, **kwargs)

        # Pool: prefer the encoder's own pooled vector, otherwise fall back
        # to the hidden state at position 0.
        # NOTE(review): position 0 is only meaningful when the tokenizer puts
        # its classification token first and pads on the right — confirm this
        # for every backbone that lacks ``pooler_output``.
        pooled = getattr(outputs, 'pooler_output', None)
        if pooled is None:
            pooled = outputs.last_hidden_state[:, 0, :]

        # The dropout layer created in __init__ was previously never used;
        # apply it to the pooled representation before the head. It is the
        # identity in eval mode and has no parameters, so existing
        # checkpoints load and evaluate exactly as before.
        pooled = self.dropout(pooled)

        # Classification
        return self.classifier(pooled)

def create_model(model_name: str, config: Config):
    """Build the classifier and matching tokenizer for ``model_name``.

    Args:
        model_name: One of the display names handled below.
        config: Supplies ``num_labels`` and ``dropout`` for the head.

    Returns:
        ``(model, tokenizer)`` where ``model`` is a ``MultiLabelClassifier``.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    logger.info(f"Creating model: {model_name}")

    # display name -> (encoder class, tokenizer class, hub id, hidden size).
    # A hidden size of None means "read it from the loaded encoder's config".
    bert_fallback = (BertModel, BertTokenizer, 'bert-base-uncased', 768)
    pretrained_specs = {
        "BERT-base-uncased": (BertModel, BertTokenizer, 'bert-base-uncased', 768),
        "RoBERTa-base": (RobertaModel, RobertaTokenizer, 'roberta-base', 768),
        "RoBERTa-large": (RobertaModel, RobertaTokenizer, 'roberta-large', 1024),
        "DeBERTa-v3-base": (DebertaV2Model, DebertaV2Tokenizer, 'microsoft/deberta-v3-base', 768),
        "ELECTRA-base": (ElectraModel, ElectraTokenizer, 'google/electra-base-discriminator', 768),
        "DistilBERT": (DistilBertModel, DistilBertTokenizer, 'distilbert-base-uncased', 768),
        "XLNet-base": (XLNetModel, XLNetTokenizer, 'xlnet-base-cased', 768),
        "BigBird-base-4096": (BigBirdModel, BigBirdTokenizer, 'google/bigbird-roberta-base', 768),
        "Reformer": (ReformerModel, ReformerTokenizer, 'google/reformer-crime-and-punishment', None),
        "FNet": (AutoModel, AutoTokenizer, 'google/fnet-base', 768),
        # No dedicated implementation: these two names reuse plain BERT.
        "Performer": bert_fallback,
        "Linformer": bert_fallback,
    }

    if model_name == "Longformer-100M":
        # The only randomly-initialised encoder: built from a custom config
        # (8 layers instead of 12), never loaded from pretrained weights.
        # Only the tokenizer comes from the published Longformer checkpoint.
        encoder = LongformerModel(LongformerConfig(
            num_hidden_layers=8,
            hidden_size=768,
            num_attention_heads=12,
            intermediate_size=3072,
            attention_window=512,
            vocab_size=50265,
        ))
        tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
        hidden_size = 768
    elif model_name in pretrained_specs:
        if model_name in ("Performer", "Linformer"):
            logger.warning(f"{model_name} uses custom attention - using BERT as base")
        encoder_cls, tokenizer_cls, hub_id, declared_width = pretrained_specs[model_name]
        encoder = encoder_cls.from_pretrained(hub_id)
        tokenizer = tokenizer_cls.from_pretrained(hub_id)
        if declared_width is None:
            hidden_size = encoder.config.hidden_size
        else:
            hidden_size = declared_width
    else:
        raise ValueError(f"Unknown model: {model_name}")

    model = MultiLabelClassifier(encoder, hidden_size, config.num_labels, config.dropout)

    # Tally parameter counts in a single pass for the log line.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    logger.info(f"Model created - Total params: {total_params/1e6:.1f}M, Trainable: {trainable_params/1e6:.1f}M")

    return model, tokenizer

# ==================== Training ====================

class Trainer:
    """Runs training, validation and checkpointing for one classifier.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits; it is moved to ``config.device``.
        tokenizer: Stored on the instance; not used by the methods below.
        config: Object providing ``device``, ``threshold``,
            ``gradient_accumulation_steps``, ``max_grad_norm``,
            ``learning_rate``, ``weight_decay``, ``num_epochs``,
            ``warmup_ratio`` and ``checkpoint_dir``.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model.to(config.device)
        self.tokenizer = tokenizer
        self.config = config
        # Best validation macro-F1 seen so far; a checkpoint is written only
        # when an epoch strictly exceeds this value.
        self.best_val_f1 = 0

    def train_epoch(self, dataloader, optimizer, scheduler):
        """Train for one epoch.

        Returns:
            Dict with the mean batch ``loss`` and the thresholded
            ``f1_macro`` / ``f1_micro`` over all training batches.
        """
        self.model.train()
        accum_steps = self.config.gradient_accumulation_steps
        num_batches = len(dataloader)
        total_loss = 0
        all_preds = []
        all_labels = []

        # Start the epoch with clean gradients so nothing accumulated
        # elsewhere can leak into the first update.
        optimizer.zero_grad()

        pbar = tqdm(dataloader, desc="Training")
        for batch_idx, batch in enumerate(pbar):
            # Move to device
            input_ids = batch['input_ids'].to(self.config.device)
            attention_mask = batch['attention_mask'].to(self.config.device)
            labels = batch['labels'].to(self.config.device)

            # Forward pass
            logits = self.model(input_ids, attention_mask)
            loss = F.binary_cross_entropy_with_logits(logits, labels)

            # Backward pass (scaled so accumulated gradients average out)
            loss = loss / accum_steps
            loss.backward()

            # Update on every full accumulation group AND on the last batch.
            # Without the second condition a trailing partial group was never
            # applied: its gradients spilled into the next epoch's first
            # update and were thrown away after the final epoch. The partial
            # group keeps the 1/accum_steps scaling, so its update is
            # proportionally smaller.
            at_group_end = (batch_idx + 1) % accum_steps == 0
            at_epoch_end = (batch_idx + 1) == num_batches
            if at_group_end or at_epoch_end:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Track metrics (undo the scaling so the reported loss is per batch)
            total_loss += loss.item() * accum_steps
            with torch.no_grad():
                # Bookkeeping only: keep it out of the autograd graph.
                probs = torch.sigmoid(logits)
                preds = (probs > self.config.threshold).float()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Calculate metrics
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)

        metrics = {
            'loss': total_loss / len(dataloader),
            'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0),
            'f1_micro': f1_score(all_labels, all_preds, average='micro', zero_division=0)
        }

        return metrics

    def evaluate(self, dataloader):
        """Evaluate the model without updating it.

        Returns:
            ``(metrics, all_preds, all_labels, all_probs)`` where ``metrics``
            is a dict of scalar scores and the three arrays hold one row per
            evaluated sample (thresholded predictions, targets, sigmoid
            probabilities).
        """
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)

                logits = self.model(input_ids, attention_mask)
                loss = F.binary_cross_entropy_with_logits(logits, labels)

                total_loss += loss.item()
                probs = torch.sigmoid(logits)
                preds = (probs > self.config.threshold).float()

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        # Calculate metrics
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)

        metrics = {
            'loss': total_loss / len(dataloader),
            'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0),
            'f1_micro': f1_score(all_labels, all_preds, average='micro', zero_division=0),
            'f1_weighted': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
            'precision_macro': precision_score(all_labels, all_preds, average='macro', zero_division=0),
            'precision_micro': precision_score(all_labels, all_preds, average='micro', zero_division=0),
            'recall_macro': recall_score(all_labels, all_preds, average='macro', zero_division=0),
            'recall_micro': recall_score(all_labels, all_preds, average='micro', zero_division=0),
            'hamming_loss': hamming_loss(all_labels, all_preds),
            'jaccard_score': jaccard_score(all_labels, all_preds, average='macro', zero_division=0),
            # accuracy_score on multi-label indicator arrays is exact-match accuracy
            'subset_accuracy': accuracy_score(all_labels, all_preds)
        }

        return metrics, all_preds, all_labels, all_probs

    def train(self, train_loader, val_loader, model_name):
        """Full training loop with per-epoch validation and checkpointing.

        A checkpoint named after ``model_name`` is written to
        ``config.checkpoint_dir`` whenever validation macro-F1 improves. The
        model is left holding the weights of the LAST epoch, not the best one.

        Returns:
            ``(train_history, val_history)``: one metrics dict per epoch each.
        """
        # Setup optimizer and scheduler
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        # Number of optimizer updates per epoch, rounded up because
        # train_epoch also steps on a trailing partial accumulation group.
        accum_steps = self.config.gradient_accumulation_steps
        updates_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
        total_steps = updates_per_epoch * self.config.num_epochs
        warmup_steps = int(total_steps * self.config.warmup_ratio)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        # Training loop
        train_history = []
        val_history = []

        for epoch in range(self.config.num_epochs):
            logger.info(f"\n{'='*50}")
            logger.info(f"Epoch {epoch+1}/{self.config.num_epochs}")

            # Train
            train_metrics = self.train_epoch(train_loader, optimizer, scheduler)
            train_history.append(train_metrics)

            # Validate
            val_metrics, _, _, _ = self.evaluate(val_loader)
            val_history.append(val_metrics)

            logger.info(f"Train Loss: {train_metrics['loss']:.4f}, F1-Macro: {train_metrics['f1_macro']:.4f}")
            logger.info(f"Val Loss: {val_metrics['loss']:.4f}, F1-Macro: {val_metrics['f1_macro']:.4f}")

            # Save best model
            if val_metrics['f1_macro'] > self.best_val_f1:
                self.best_val_f1 = val_metrics['f1_macro']

                # The Config object itself is stored, so the file must be
                # loaded with full unpickling (weights_only=False).
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_f1': self.best_val_f1,
                    'val_metrics': val_metrics,
                    'config': self.config
                }

                checkpoint_path = self.config.checkpoint_dir / f"{model_name.lower().replace(' ', '_')}_best.pt"
                torch.save(checkpoint, checkpoint_path)
                logger.info(f"Saved best model to {checkpoint_path}")

        return train_history, val_history

# ==================== TAN Model Loading ====================

def load_tan_results():
    """Read the TAN baseline's metrics from ``goemotion_best_model.pt``.

    Returns:
        ``None`` when the checkpoint file is absent. Otherwise the
        checkpoint's ``metrics`` entry; if that is missing or empty and a
        ``best_val_f1`` entry exists, a small dict built from that value.
    """
    tan_checkpoint = Path('goemotion_best_model.pt')

    if tan_checkpoint.exists():
        logger.info(f"Loading TAN results from {tan_checkpoint}")
        # weights_only=False performs full unpickling; only point this at a
        # checkpoint file you trust.
        checkpoint = torch.load(tan_checkpoint, map_location='cpu', weights_only=False)

        stored_metrics = checkpoint.get('metrics', {})
        if stored_metrics or 'best_val_f1' not in checkpoint:
            return stored_metrics

        # No metrics recorded: fall back to the single stored F1 value.
        return {
            'f1_macro': checkpoint.get('best_val_f1', 0),
            'model': 'TAN'
        }

    logger.warning(f"TAN checkpoint not found at {tan_checkpoint}")
    return None

# ==================== Main Execution ====================

def _test_metrics_for_model(model_name, config, train_data, val_data, test_data):
    """Return test-set metrics for one model, training it first if no checkpoint exists."""
    train_texts, train_labels = train_data
    val_texts, val_labels = val_data
    test_texts, test_labels = test_data

    checkpoint_path = config.checkpoint_dir / f"{model_name.lower().replace(' ', '_')}_best.pt"
    already_trained = checkpoint_path.exists()

    if already_trained:
        logger.info(f"Found existing checkpoint for {model_name}, loading...")
    else:
        logger.info(f"Training {model_name} from scratch...")

    model, tokenizer = create_model(model_name, config)
    trainer = Trainer(model, tokenizer, config)

    if not already_trained:
        train_dataset = GoEmotionsDataset(train_texts, train_labels, tokenizer, config.max_length)
        val_dataset = GoEmotionsDataset(val_texts, val_labels, tokenizer, config.max_length)
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
        trainer.train(train_loader, val_loader, model_name)

    # Always score the best-validation weights. Previously a freshly trained
    # model was tested with its last-epoch weights while a cached model was
    # tested with its best checkpoint, so reruns could report different
    # numbers for the same model.
    if checkpoint_path.exists():
        # weights_only=False: the checkpoint also stores a pickled Config.
        checkpoint = torch.load(checkpoint_path, map_location=config.device, weights_only=False)
        trainer.model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Happens when validation F1 never rose above the initial best value.
        logger.warning(f"No checkpoint was saved for {model_name}; evaluating last-epoch weights")

    test_dataset = GoEmotionsDataset(test_texts, test_labels, tokenizer, config.max_length)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
    test_metrics, _, _, _ = trainer.evaluate(test_loader)
    return test_metrics


def main():
    """Main training and evaluation pipeline.

    Loads the data, trains (or reloads) every model in the list below,
    evaluates each on the test split, logs a ranking by macro-F1 together
    with the difference to the TAN baseline when available, and writes all
    results to ``goemotions_final_results.json``.
    """

    logger.info("="*80)
    logger.info("GoEmotions Multi-Label Classification - Training and Evaluation")
    logger.info("="*80)

    config = Config()

    # Load dataset
    train_data, val_data, test_data, emotion_labels = load_goemotions_data()

    # Define models to train
    models_to_train = [
        "BERT-base-uncased",
        "RoBERTa-base",
        "RoBERTa-large",
        "DeBERTa-v3-base",
        "ELECTRA-base",
        "DistilBERT",
        "XLNet-base",
        "Longformer-100M",  # custom 8-layer Longformer, randomly initialised
        "BigBird-base-4096",
        "Performer",
        "Linformer"
    ]

    # Results storage
    all_results = {}

    # Load TAN results
    tan_results = load_tan_results()
    if tan_results:
        all_results['TAN'] = tan_results
        logger.info(f"TAN baseline F1-Macro: {tan_results.get('f1_macro', 0):.4f}")

    # Train each model
    for model_name in models_to_train:
        logger.info(f"\n{'='*80}")
        logger.info(f"Training {model_name}")
        logger.info(f"{'='*80}")

        try:
            test_metrics = _test_metrics_for_model(
                model_name, config, train_data, val_data, test_data
            )
        except Exception as e:
            # Top-level boundary: one failing model must not abort the whole
            # comparison. logger.exception keeps the traceback in the log.
            logger.exception(f"Failed to train/evaluate {model_name}: {e}")
            all_results[model_name] = {'error': str(e)}
        else:
            # Store results
            all_results[model_name] = test_metrics

            logger.info(f"\n{model_name} Test Results:")
            logger.info(f"  F1-Macro: {test_metrics['f1_macro']:.4f}")
            logger.info(f"  F1-Micro: {test_metrics['f1_micro']:.4f}")
            logger.info(f"  Hamming Loss: {test_metrics['hamming_loss']:.4f}")
            logger.info(f"  Subset Accuracy: {test_metrics['subset_accuracy']:.4f}")
        finally:
            # The helper's model and trainer are out of scope here; return
            # their cached GPU memory before the next model is loaded.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Generate final report
    logger.info("\n" + "="*80)
    logger.info("FINAL RESULTS COMPARISON")
    logger.info("="*80)

    # Sort by F1-Macro (entries that only hold an error are skipped)
    sorted_results = sorted(
        [(name, res) for name, res in all_results.items() if 'f1_macro' in res],
        key=lambda x: x[1]['f1_macro'],
        reverse=True
    )

    logger.info("\nModel Rankings (by F1-Macro):")
    for rank, (model_name, metrics) in enumerate(sorted_results, 1):
        logger.info(f"{rank}. {model_name:20s} - F1-Macro: {metrics['f1_macro']:.4f}, "
                      f"F1-Micro: {metrics.get('f1_micro', 0):.4f}, "
                      f"Hamming: {metrics.get('hamming_loss', 0):.4f}")

    # Comparison with TAN (plain score differences, no significance test)
    if len(sorted_results) > 1 and 'TAN' in all_results:
        logger.info("\nStatistical Comparison with TAN:")
        tan_f1 = all_results['TAN'].get('f1_macro', 0)

        for model_name, metrics in sorted_results:
            if model_name != 'TAN':
                model_f1 = metrics['f1_macro']
                # Difference from the model's point of view, so a positive
                # value really means "this model is better than TAN". The
                # previous TAN-minus-model sign inverted the label.
                diff = model_f1 - tan_f1
                logger.info(f"  {model_name}: {'+' if diff > 0 else ''}{diff:.4f} "
                              f"({'better' if diff > 0 else 'worse'} than TAN)")

    # Save results
    results_path = Path('goemotions_final_results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        # default=str stringifies anything json cannot encode natively
        # (e.g. NumPy scalars returned by the metric functions).
        json.dump(all_results, f, indent=2, default=str)

    logger.info(f"\nResults saved to {results_path}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()