

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm


import datasets
datasets.disable_caching()

from datasets import load_dataset, concatenate_datasets

# Fixed imports for compatible versions
from transformers import (
    RobertaTokenizer,
    RobertaModel,
    RobertaConfig,
    RobertaForSequenceClassification
)

from transformers.models.roberta.modeling_roberta import (
    RobertaAttention, RobertaSelfAttention, RobertaLayer,
    RobertaEncoder, RobertaClassificationHead
)

from torch.utils.data import DataLoader
from torch.optim import AdamW
import logging

# Silence warnings emitted by transformers' tokenization utilities.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

# Select the compute device: CUDA when available, otherwise CPU.
# To force CPU execution, hide the GPUs (e.g. CUDA_VISIBLE_DEVICES="").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Output directories used by this script (checkpoints go to results/models).
os.makedirs("results", exist_ok=True)
os.makedirs("results/models", exist_ok=True)
os.makedirs("results/metrics", exist_ok=True)

# Seed the torch and NumPy RNGs for reproducibility.
torch.manual_seed(42)
np.random.seed(42)


class VanillaStudentRobertaForSequenceClassification(RobertaForSequenceClassification):
    """Vanilla student model: a shallower RoBERTa initialised from a teacher.

    When ``original_model`` is given, the student starts from independent
    (deep-copied) versions of the teacher's embeddings, its first
    ``config.num_hidden_layers`` encoder layers, its pooler and its
    classification head.

    Args:
        config: Student ``RobertaConfig``; ``num_hidden_layers`` sets the depth.
        original_model: Optional teacher whose modules seed the student. It must
            have at least ``config.num_hidden_layers`` encoder layers.
    """

    def __init__(self, config, original_model=None):
        # The parent constructor builds fresh modules and already runs post_init().
        super().__init__(config)

        if original_model is not None:
            import copy

            # Deep-copy every module taken from the teacher. Assigning the teacher's
            # modules directly would make both models share parameter tensors, so the
            # student's optimizer would silently update the teacher during distillation.
            self.roberta.embeddings = copy.deepcopy(original_model.roberta.embeddings)

            # Seed the student's layers with the teacher's first few layers.
            teacher_layers = original_model.roberta.encoder.layer
            for i in range(config.num_hidden_layers):
                self.roberta.encoder.layer[i] = copy.deepcopy(teacher_layers[i])

            self.roberta.pooler = copy.deepcopy(original_model.roberta.pooler)
            self.classifier = copy.deepcopy(original_model.classifier)

        # NOTE: post_init() is deliberately NOT called again here. Depending on the
        # transformers version, its weight initialisation can overwrite copied modules
        # (e.g. a head built directly with RobertaClassificationHead), discarding the
        # teacher's trained weights.


def load_pretrained_teacher_model(num_layers=6, num_classes=3):
    """Build a shallow RoBERTa classifier seeded from pretrained ``roberta-base``.

    The embeddings, the first ``num_layers`` encoder layers and the pooler
    attribute are taken from the pretrained checkpoint; the classification head
    is freshly constructed for ``num_classes`` outputs.

    Args:
        num_layers: Number of encoder layers to keep.
        num_classes: Number of output labels.

    Returns:
        Tuple ``(model, config)``: the truncated model and its config.
    """
    print("Loading pretrained RoBERTa model...")

    pretrained = RobertaForSequenceClassification.from_pretrained("roberta-base")
    teacher_config = RobertaConfig.from_pretrained("roberta-base")
    teacher_config.num_hidden_layers = num_layers
    teacher_config.hidden_dropout_prob = 0.1
    teacher_config.attention_probs_dropout_prob = 0.1
    teacher_config.num_labels = num_classes

    teacher = RobertaForSequenceClassification(teacher_config)
    source_roberta = pretrained.roberta
    target_roberta = teacher.roberta

    target_roberta.embeddings = source_roberta.embeddings
    for layer_idx in range(num_layers):
        target_roberta.encoder.layer[layer_idx] = source_roberta.encoder.layer[layer_idx]
    target_roberta.pooler = source_roberta.pooler

    # A new head sized for num_classes replaces the one built by the constructor.
    teacher.classifier = RobertaClassificationHead(teacher_config)

    print(f"Created teacher model with {num_layers} layers and {num_classes} output classes")
    return teacher, teacher_config


def create_vanilla_student_model(config, teacher_model, num_student_layers=2):
    """Build a vanilla student: a shallower RoBERTa seeded from ``teacher_model``.

    Args:
        config: Teacher config; its dropout probabilities and ``num_labels`` are
            reused for the student.
        teacher_model: Teacher whose modules initialise the student.
        num_student_layers: Encoder depth of the student.

    Returns:
        A ``VanillaStudentRobertaForSequenceClassification`` instance.
    """
    print(f"Creating vanilla student model with {num_student_layers} layers from teacher...")

    # Start from the stock roberta-base config and override only what differs.
    student_config = RobertaConfig.from_pretrained("roberta-base")
    overrides = {
        "num_hidden_layers": num_student_layers,
        "hidden_dropout_prob": config.hidden_dropout_prob,
        "attention_probs_dropout_prob": config.attention_probs_dropout_prob,
        "num_labels": config.num_labels,
    }
    for attr_name, attr_value in overrides.items():
        setattr(student_config, attr_name, attr_value)

    student = VanillaStudentRobertaForSequenceClassification(
        student_config, original_model=teacher_model
    )

    print(f"Successfully created vanilla student model with {num_student_layers} layers")
    return student


def load_mnli_data(batch_size=16, max_length=128, split="train", num_proc=4,
                   train_samples=50000, val_samples=5000):
    """Load, subsample and tokenize GLUE MNLI, and wrap it in DataLoaders.

    Args:
        batch_size: Batch size for both loaders.
        max_length: Every premise/hypothesis pair is truncated and padded to this
            many tokens.
        split: Currently unused; kept for backward compatibility. The ``train``
            split and ``validation_matched`` are always loaded.
        num_proc: Worker processes used by ``Dataset.map`` during tokenization.
        train_samples: Number of shuffled training examples to keep
            (``None`` keeps the whole split).
        val_samples: Number of shuffled ``validation_matched`` examples to keep
            (``None`` keeps the whole split).

    Returns:
        ``(tokenizer, train_loader, val_loader, label_map)``. Batches are dicts of
        tensors with keys ``input_ids``, ``attention_mask`` and ``qa_class`` (the
        MNLI label).
    """
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    print("Loading MNLI dataset...")
    dataset = load_dataset("glue", "mnli")

    def subsample(ds, limit):
        """Shuffle ``ds`` with a fixed seed and keep at most ``limit`` rows (all if None)."""
        shuffled = ds.shuffle(seed=42)
        if limit is None:
            return shuffled
        return shuffled.select(range(min(limit, len(shuffled))))

    # Subsets keep training fast; pass train_samples/val_samples=None for full splits.
    print("Preparing data subsets...")
    train_dataset = subsample(dataset["train"], train_samples)
    val_dataset = subsample(dataset["validation_matched"], val_samples)

    def preprocess(examples):
        """Tokenize premise/hypothesis pairs into fixed-length sequences."""
        return tokenizer(
            examples["premise"],
            examples["hypothesis"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors=None,  # Return lists, not tensors
            verbose=False,  # Suppress warnings
        )

    print("Tokenizing training data...")
    tokenized_train = train_dataset.map(
        preprocess,
        batched=True,
        num_proc=num_proc,
        remove_columns=train_dataset.column_names,
        desc="Tokenizing train"
    )

    print("Tokenizing validation data...")
    tokenized_val = val_dataset.map(
        preprocess,
        batched=True,
        num_proc=num_proc,
        remove_columns=val_dataset.column_names,
        desc="Tokenizing validation"
    )

    # map() dropped the original columns, so re-attach the labels under the name
    # the training loops read (this relies on map() preserving row order).
    tokenized_train = tokenized_train.add_column("qa_class", train_dataset["label"])
    tokenized_val = tokenized_val.add_column("qa_class", val_dataset["label"])

    # Expose only the model inputs and the label, as torch tensors.
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])
    tokenized_val.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])

    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size, num_workers=2)

    label_map = {
        0: "entailment",
        1: "neutral",
        2: "contradiction"
    }

    print(f"MNLI Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")
    return tokenizer, train_loader, val_loader, label_map


def finetune_teacher_model(model, train_loader, val_loader, epochs=3, num_classes=3, label_map=None):
    """Fine-tune the teacher classifier with cross-entropy and early stopping.

    Uses AdamW (lr 1e-5; weight decay 0.01 except on biases and LayerNorm
    weights), clips gradients to norm 1.0 and halves the learning rate every two
    epochs. After each epoch the model is evaluated on ``val_loader``. The best
    checkpoint is written to ``results/models/vanilla_teacher_best.pt``, and
    training stops after two consecutive epochs without improvement.

    Args:
        model: Classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning an object with ``.logits``. Assumed to already be on
            ``device``.
        train_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``qa_class`` tensors.
        val_loader: Same batch format as ``train_loader``.
        epochs: Maximum number of epochs.
        num_classes: Number of labels, used for per-class accuracy.
        label_map: Optional label-id -> name mapping used for logging; defaults
            to the MNLI label names.

    Returns:
        ``(model, best_val_acc)``. If this call saved a checkpoint, the model's
        weights are restored from it before returning.
    """
    checkpoint_path = "results/models/vanilla_teacher_best.pt"
    print(f"Finetuning teacher model for up to {epochs} epochs...")

    # No weight decay on biases and LayerNorm weights (standard transformer fine-tuning).
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

    # Stepped once per epoch, so the LR halves every two epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    loss_fn = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    saved_checkpoint = False  # becomes True once THIS call writes a checkpoint
    patience = 2
    no_improvement = 0

    if label_map is None:
        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # loss is a batch mean; weight it by batch size for an exact epoch average.
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        # max(..., 1) guards against an empty loader.
        epoch_train_loss = train_loss / max(train_total, 1)
        epoch_train_acc = train_correct / max(train_total, 1)

        # ---- Validation ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                correct = preds == labels
                val_correct += correct.sum().item()
                val_total += labels.size(0)

                # Per-class accuracy: hits and totals for each gold label.
                for i in range(num_classes):
                    class_mask = labels == i
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += (correct & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / max(val_total, 1)
        epoch_val_acc = val_correct / max(val_total, 1)

        print(f"Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save the best model. The first epoch always saves, so the reload below
        # never picks up a stale file from an earlier run.
        if epoch_val_acc > best_val_acc or not saved_checkpoint:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            print(f"  New best teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Restore the best weights, but only from a checkpoint written by this call.
    if saved_checkpoint:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best teacher model with val accuracy: {best_val_acc:.4f}")
    else:
        print("No teacher checkpoint was saved in this run; keeping the current weights.")

    return model, best_val_acc


def train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=3,
        learning_rate=5e-6,
        weight_decay=0.01,
        alpha=0.3,  # Weight on distillation loss
        temperature=2.0,
        num_classes=3
):
    """Train the student on hard labels plus the teacher's softened predictions.

    The per-batch loss is
    ``(1 - alpha) * CE(student_logits, labels)
    + alpha * T**2 * KL(softmax(teacher/T) || softmax(student/T))``.
    The teacher is only run in eval mode without gradients. After each epoch the
    student is validated with the same combined loss. The best checkpoint is
    written to ``results/models/vanilla_student_best.pt``, and training stops
    after two consecutive epochs without improvement.

    Args:
        student_model: Model being trained; called as
            ``model(input_ids=..., attention_mask=...)`` and returning ``.logits``.
        teacher_model: Fixed teacher with the same call convention.
        train_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``qa_class`` tensors.
        val_loader: Same batch format as ``train_loader``.
        epochs: Maximum number of epochs.
        learning_rate: AdamW learning rate.
        weight_decay: Weight decay for all parameters except biases and
            LayerNorm weights.
        alpha: Weight of the distillation (KL) term; ``1 - alpha`` weights CE.
        temperature: Softmax temperature ``T`` applied to both models' logits.
        num_classes: Number of labels, used for per-class accuracy.

    Returns:
        ``(student_model, best_val_acc)``. If this call saved a checkpoint, the
        student's weights are restored from it before returning.
    """
    checkpoint_path = "results/models/vanilla_student_best.pt"
    print(f"Training vanilla student model with distillation for up to {epochs} epochs...")

    # If the student reuses the teacher's parameter tensors, every optimizer step
    # on the student also changes the teacher. Make that visible.
    teacher_param_ids = {id(p) for p in teacher_model.parameters()}
    if any(id(p) in teacher_param_ids for p in student_model.parameters()):
        print("Warning: student and teacher share parameters; "
              "optimizing the student will also modify the teacher.")

    # No weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Stepped once per epoch: LR is multiplied by 0.8 every two epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)

    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    def distillation_loss(student_logits, teacher_logits, labels):
        """Return ``(combined, ce, kl)`` losses for one batch."""
        ce_loss = ce_loss_fn(student_logits, labels)
        # KLDivLoss takes log-probabilities as input and probabilities as target.
        # The conventional T**2 factor keeps the soft-target term's gradient scale
        # comparable across temperatures.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        kl_loss = kl_loss_fn(student_log_probs, teacher_probs) * (temperature ** 2)
        return (1 - alpha) * ce_loss + alpha * kl_loss, ce_loss, kl_loss

    best_val_acc = 0.0
    saved_checkpoint = False  # becomes True once THIS call writes a checkpoint
    patience = 2
    no_improvement = 0
    label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    for epoch in range(epochs):
        # ---- Training ----
        student_model.train()
        teacher_model.eval()

        train_loss = 0.0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Vanilla Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            optimizer.zero_grad()

            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

            loss, ce_loss, kl_loss = distillation_loss(student_logits, teacher_logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()

            # loss is a batch mean; weight it by batch size for an exact epoch average.
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "ce_loss": ce_loss.item(),
                "kl_loss": kl_loss.item(),
                "acc": train_correct / train_total,
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        # max(..., 1) guards against an empty loader.
        epoch_train_loss = train_loss / max(train_total, 1)
        epoch_train_acc = train_correct / max(train_total, 1)

        # ---- Validation ----
        student_model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

                loss, _, _ = distillation_loss(student_logits, teacher_logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(student_logits, dim=-1)
                correct = preds == labels
                val_correct += correct.sum().item()
                val_total += labels.size(0)

                # Per-class accuracy: hits and totals for each gold label.
                for i in range(num_classes):
                    class_mask = labels == i
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += (correct & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / max(val_total, 1)
        epoch_val_acc = val_correct / max(val_total, 1)

        print(f"Vanilla Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save the best model. The first epoch always saves, so the reload below
        # never picks up a stale file from an earlier run.
        if epoch_val_acc > best_val_acc or not saved_checkpoint:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(student_model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            print(f"  New best vanilla student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Restore the best weights, but only from a checkpoint written by this call.
    if saved_checkpoint:
        student_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best vanilla student model with val accuracy: {best_val_acc:.4f}")
    else:
        print("No student checkpoint was saved in this run; keeping the current weights.")

    return student_model, best_val_acc


def main_vanilla_comparison():
    """Run the vanilla teacher -> student distillation experiment end to end.

    Steps:
    1. Load a subsampled MNLI.
    2. Fine-tune a truncated ``roberta-base`` teacher.
    3. Distill the teacher into a shallower student.
    4. Compare both models' validation accuracy and parameter counts.

    Returns:
        Dict with ``teacher_acc``, ``student_acc``, ``teacher_params``,
        ``student_params`` and ``compression_ratio`` (teacher params divided
        by student params).
    """
    print("\n=== Vanilla RoBERTa Student-Teacher Knowledge Distillation ===\n")

    # Same parameters as the higher-order version for fair comparison
    NUM_TEACHER_LAYERS = 3
    NUM_STUDENT_LAYERS = 2  # Smaller student
    BATCH_SIZE = 8
    MAX_LENGTH = 64
    TEACHER_EPOCHS = 3
    STUDENT_EPOCHS = 3
    DISTILLATION_ALPHA = 0.3
    TEMPERATURE = 2.0
    NUM_CLASSES = 3

    print("Configuration:")
    print(f"  Teacher Layers: {NUM_TEACHER_LAYERS}")
    print(f"  Student Layers: {NUM_STUDENT_LAYERS}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Max Length: {MAX_LENGTH}")
    print(f"  Distillation Alpha: {DISTILLATION_ALPHA}")
    print(f"  Temperature: {TEMPERATURE}")

    # Batches come pre-tokenized, so the returned tokenizer is not needed here.
    _tokenizer, train_loader, val_loader, label_map = load_mnli_data(
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        num_proc=4  # Use multiprocessing for faster data loading
    )

    teacher_model, roberta_config = load_pretrained_teacher_model(
        num_layers=NUM_TEACHER_LAYERS,
        num_classes=NUM_CLASSES
    )
    teacher_model.to(device)

    print("\n=== Step 1: Finetuning Vanilla RoBERTa Teacher Model ===\n")
    teacher_model, _teacher_val_acc = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=TEACHER_EPOCHS,
        num_classes=NUM_CLASSES,
        label_map=label_map
    )

    print("\n=== Step 2: Training Vanilla RoBERTa Student Model with Distillation ===\n")
    student_model = create_vanilla_student_model(
        roberta_config,
        teacher_model,
        num_student_layers=NUM_STUDENT_LAYERS
    )
    student_model.to(device)

    student_model, best_val_acc = train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=STUDENT_EPOCHS,
        learning_rate=5e-6,
        weight_decay=0.01,
        alpha=DISTILLATION_ALPHA,
        temperature=TEMPERATURE,
        num_classes=NUM_CLASSES
    )

    print("\n=== Training Complete ===")
    print(f"Best vanilla student model validation accuracy: {best_val_acc:.4f}")
    print("Models saved to: results/models/")

    print("\n=== Comparing Vanilla Teacher vs Vanilla Student Performance ===")
    teacher_final_acc, student_final_acc = _evaluate_pair_accuracy(
        teacher_model, student_model, val_loader
    )

    print(f"Final Vanilla Teacher Accuracy: {teacher_final_acc:.4f}")
    print(f"Final Vanilla Student Accuracy: {student_final_acc:.4f}")
    print(f"Performance Gap (Student - Teacher): {(student_final_acc - teacher_final_acc):.4f}")

    teacher_params = _count_parameters(teacher_model)
    student_params = _count_parameters(student_model)
    compression_ratio = teacher_params / student_params

    print("\nModel Size Comparison:")
    print(f"Teacher Parameters: {teacher_params:,}")
    print(f"Student Parameters: {student_params:,}")
    print(f"Compression Ratio: {compression_ratio:.2f}x")

    return {
        'teacher_acc': teacher_final_acc,
        'student_acc': student_final_acc,
        'teacher_params': teacher_params,
        'student_params': student_params,
        'compression_ratio': compression_ratio
    }


def _evaluate_pair_accuracy(teacher_model, student_model, data_loader):
    """Return ``(teacher_acc, student_acc)`` of both models on ``data_loader``.

    Both models are put in eval mode and run without gradient tracking on the
    same batches. Returns ``(0.0, 0.0)`` for an empty loader.
    """
    teacher_model.eval()
    student_model.eval()

    teacher_correct = 0
    student_correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Final Evaluation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
            teacher_correct += (torch.argmax(teacher_logits, dim=-1) == labels).sum().item()

            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
            student_correct += (torch.argmax(student_logits, dim=-1) == labels).sum().item()

            total += labels.size(0)

    if total == 0:
        return 0.0, 0.0
    return teacher_correct / total, student_correct / total


def _count_parameters(model):
    """Return the total number of elements across ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


if __name__ == "__main__":
    results = main_vanilla_comparison()