
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Vanilla RoBERTa Student-Teacher Knowledge Distillation
Standard attention mechanisms for comparison with higher-order attention.

"""

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Make sure the output directory tree exists before anything writes to it
for _results_dir in ("results", "results/models", "results/metrics"):
    os.makedirs(_results_dir, exist_ok=True)

# Fix the PyTorch and NumPy RNG seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


class VanillaRobertaStudentModel(RobertaForSequenceClassification):
    """
    Vanilla RoBERTa student model with standard attention mechanisms.

    Architecturally this is a plain ``RobertaForSequenceClassification`` built
    from ``config``. When a teacher is supplied, the student starts from
    independent copies of the teacher's embeddings, pooler, leading encoder
    layers and classification head.

    Args:
        config: Configuration of the student; ``config.num_hidden_layers``
            determines how many encoder layers are taken from the teacher.
        teacher_model: Optional teacher exposing ``roberta.embeddings``,
            ``roberta.pooler``, ``roberta.encoder.layer`` and ``classifier``.
            When ``None`` the student keeps its freshly initialised weights.
    """

    def __init__(self, config, teacher_model=None):
        super().__init__(config)

        # Finalise initialisation BEFORE grafting teacher weights, so that any
        # (re-)initialisation done by post_init() can never touch the copies.
        self.post_init()

        if teacher_model is None:
            return

        import copy

        teacher_roberta = teacher_model.roberta

        # Deep copies are essential: assigning the teacher's modules directly
        # would make student and teacher share parameters, so optimising the
        # student would silently modify the (supposedly fixed) teacher too.
        self.roberta.embeddings = copy.deepcopy(teacher_roberta.embeddings)
        self.roberta.pooler = copy.deepcopy(teacher_roberta.pooler)

        # Take the first config.num_hidden_layers layers, or fewer if the
        # teacher is shallower; any remaining student layers keep their
        # fresh initialisation.
        teacher_layers = teacher_roberta.encoder.layer
        num_copied = min(config.num_hidden_layers, len(teacher_layers))
        for i in range(num_copied):
            self.roberta.encoder.layer[i] = copy.deepcopy(teacher_layers[i])

        # Start from the teacher's classification head as well
        self.classifier = copy.deepcopy(teacher_model.classifier)


def load_pretrained_teacher_model(num_layers=6, num_classes=5):
    """Load and adapt pretrained RoBERTa model as teacher.

    Builds a ``roberta-base`` configuration truncated to ``num_layers`` encoder
    layers, with dropout 0.3 and ``num_classes`` labels, and loads the
    pretrained checkpoint into it. Checkpoint weights for layers beyond
    ``num_layers`` are simply not used, and the classification head is newly
    initialised for ``num_classes`` outputs.

    Loading through the modified config (instead of transplanting modules from
    a separately loaded default model) ensures the configured dropout rates
    are actually the ones used by the embeddings and encoder layers.

    Args:
        num_layers: Number of leading encoder layers to keep.
        num_classes: Number of output classes of the classification head.

    Returns:
        Tuple ``(model, config)`` with the truncated model and its config.

    Raises:
        ValueError: If ``num_layers`` is negative or exceeds the number of
            layers in the pretrained checkpoint, or if the resulting
            classifier does not have ``num_classes`` outputs.
    """
    print("Loading pretrained RoBERTa model...")

    # Create config with fewer layers but same dimensions
    config = RobertaConfig.from_pretrained("roberta-base")
    pretrained_layers = config.num_hidden_layers
    if not 0 <= num_layers <= pretrained_layers:
        raise ValueError(
            f"num_layers must be between 0 and {pretrained_layers}, got {num_layers}"
        )
    config.num_hidden_layers = num_layers
    config.hidden_dropout_prob = 0.3
    config.attention_probs_dropout_prob = 0.3
    config.num_labels = num_classes

    # Load the pretrained weights straight into the truncated architecture:
    # embeddings and the first num_layers layers come from the checkpoint,
    # the classification head is freshly initialised.
    new_model = RobertaForSequenceClassification.from_pretrained("roberta-base", config=config)

    print(f"Verifying model output size: {new_model.classifier.out_proj.out_features} classes")
    if new_model.classifier.out_proj.out_features != num_classes:
        # Explicit check instead of assert, which is stripped under -O
        raise ValueError("Model output size doesn't match num_classes")

    print(f"Created teacher model with {num_layers} layers from pretrained RoBERTa with {num_classes} output classes")
    return new_model, config


def create_vanilla_student_model(config, teacher_model, student_layers=2):
    """Build the vanilla student model from a teacher.

    Args:
        config: Teacher configuration; only ``num_labels`` is read from it.
        teacher_model: Teacher passed on to ``VanillaRobertaStudentModel``.
        student_layers: Number of encoder layers of the student.

    Returns:
        The newly constructed ``VanillaRobertaStudentModel``.
    """
    print(f"Creating vanilla student model with {student_layers} layers from teacher...")

    # Start from the roberta-base defaults and override what differs
    overrides = {
        "num_hidden_layers": student_layers,
        "hidden_dropout_prob": 0.3,
        "attention_probs_dropout_prob": 0.3,
        "num_labels": config.num_labels,
    }
    student_config = RobertaConfig.from_pretrained("roberta-base")
    for field_name, field_value in overrides.items():
        setattr(student_config, field_name, field_value)

    model = VanillaRobertaStudentModel(student_config, teacher_model)

    print(f"Successfully created vanilla student model with {student_layers} layers")
    return model


def _balanced_class_subset(split, per_class, num_classes):
    """Return up to ``per_class`` shuffled examples of each coarse label from ``split``."""
    parts = []
    for class_id in range(num_classes):
        # filter() runs eagerly, so the closure over class_id is evaluated
        # while class_id still has the value of this iteration.
        class_samples = split.filter(lambda example: example["coarse_label"] == class_id)
        keep = min(per_class, len(class_samples))
        parts.append(class_samples.shuffle(seed=42).select(range(keep)))
    return concatenate_datasets(parts)


def _print_class_distribution(title, labels, num_classes, label_map):
    """Print the count and percentage of every class id in ``labels``."""
    from collections import Counter

    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())

    print(title)
    for class_id in range(num_classes):
        count = counts[class_id]
        class_name = label_map.get(class_id, f"Class {class_id}")
        # Guard against an empty split instead of dividing by zero
        share = count / total * 100 if total else 0.0
        print(f"  {class_name}: {count} samples ({share:.2f}%)")


def load_and_prepare_trec_data(batch_size=16, max_length=128, max_samples=5000, num_classes=5):
    """Load and prepare TREC dataset for training with memory constraints.

    Keeps only examples whose ``coarse_label`` is below ``num_classes``. If
    ``max_samples`` is smaller than the filtered training set, the training
    split is reduced to at most ``max_samples // num_classes`` examples per
    class and the test split to at most ``(max_samples // 2) // num_classes``
    per class. All texts are tokenized with the ``roberta-base`` tokenizer,
    padded/truncated to ``max_length``.

    Args:
        batch_size: Batch size for training
        max_length: Maximum sequence length for truncation
        max_samples: Maximum number of samples to use (for memory constraints)
        num_classes: Number of classes to use (5 by default)

    Returns:
        Tuple ``(tokenizer, train_loader, val_loader)``. Batches carry the
        columns "input_ids", "attention_mask" and "coarse_label"; the
        validation loader is built from the TREC test split.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    print(f"Loading and preparing TREC dataset (limited to {max_samples} samples, {num_classes} classes)...")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    dataset = load_dataset("trec")  # TREC dataset with 6 classes

    # Select only the first num_classes
    train_data = dataset["train"].filter(lambda example: example["coarse_label"] < num_classes)
    test_data = dataset["test"].filter(lambda example: example["coarse_label"] < num_classes)

    # Limit dataset size for memory constraints
    if max_samples and max_samples < len(train_data):
        # Class-balanced subsets; the test split gets half the sample budget
        train_subset = _balanced_class_subset(train_data, max_samples // num_classes, num_classes)
        test_subset = _balanced_class_subset(test_data, (max_samples // 2) // num_classes, num_classes)

        print(f"Limited dataset to {len(train_subset)} train and {len(test_subset)} test samples")
    else:
        train_subset = train_data
        test_subset = test_data

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    # Tokenize data
    tokenized_train = train_subset.map(tokenize_function, batched=True)
    tokenized_val = test_subset.map(tokenize_function, batched=True)

    # Format for PyTorch - note the label column of TREC is "coarse_label"
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "coarse_label"])
    tokenized_val.set_format(type="torch", columns=["input_ids", "attention_mask", "coarse_label"])

    # Create dataloaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    print(f"Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")

    # Map numeric labels to class names for better understanding
    label_map = {0: "ABBR", 1: "DESC", 2: "ENTY", 3: "HUM", 4: "LOC", 5: "NUM"}

    # Count labels from the raw label column of the untokenized subsets; this
    # avoids materialising every padded example as tensors just to read its label.
    _print_class_distribution(
        "Train class distribution:", train_subset["coarse_label"], num_classes, label_map
    )
    _print_class_distribution(
        "Validation class distribution:", test_subset["coarse_label"], num_classes, label_map
    )

    return tokenizer, train_loader, val_loader


def finetune_teacher_model(model, train_loader, val_loader, epochs=1, num_classes=5):
    """Finetune the pretrained RoBERTa teacher model with early stopping.

    Trains with AdamW (lr 2e-5, weight decay 0.03 except on biases and
    LayerNorm weights), a linear-warmup / cosine-decay schedule over
    ``len(train_loader) * epochs`` steps, and gradient-norm clipping at 1.0.
    After every epoch the model is evaluated on ``val_loader``. The weights
    with the best validation accuracy are written to
    ``results/models/vanilla_teacher_best.pt`` and restored before returning.
    Training stops early after 2 consecutive epochs without improvement.
    Per-epoch metrics are written to CSV files under ``results/metrics``.

    Args:
        model: Classification model whose forward pass returns an object with
            a ``logits`` attribute. Batches are moved to the module-level
            ``device``; the model itself is not moved by this function.
        train_loader: Sized iterable of batches, each a mapping with
            "input_ids", "attention_mask" and "coarse_label" tensors.
        val_loader: Iterable of validation batches with the same keys.
        epochs: Maximum number of training epochs.
        num_classes: Number of class ids (``0 .. num_classes - 1``) for which
            per-class validation accuracy is reported.

    Returns:
        Tuple ``(model, best_val_acc)``.
    """
    print(f"Finetuning teacher model for up to {epochs} epochs with early stopping...")

    best_checkpoint_path = "results/models/vanilla_teacher_best.pt"

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.03,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

    # Learning rate scheduler with warmup (10% of all planned steps)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    # Loss function
    loss_fn = nn.CrossEntropyLoss()

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS run wrote a checkpoint, so a stale file left over
    # from an earlier run is never loaded by mistake.
    checkpoint_saved = False
    train_stats = []
    val_stats = []

    # Define class names for better interpretability
    label_map = {0: "ABBR", 1: "DESC", 2: "ENTY", 3: "HUM", 4: "LOC", 5: "NUM"}

    # Define valid classes based on num_classes
    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            # Get batch data
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Using coarse_label for TREC dataset
            labels = batch["coarse_label"].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Calculate loss
            loss = loss_fn(logits, labels)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track metrics (loss is a batch mean, so weight it by batch size)
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            # Update progress bar
            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        # Calculate epoch stats (max(..., 1) guards against an empty loader)
        epoch_train_loss = train_loss / max(train_total, 1)
        epoch_train_acc = train_correct / max(train_total, 1)

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        # Per-class metrics start from zero for every validation pass
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Using coarse_label for TREC dataset
                labels = batch["coarse_label"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                # Calculate loss
                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        # Calculate validation stats
        epoch_val_loss = val_loss / max(val_total, 1)
        epoch_val_acc = val_correct / max(val_total, 1)

        # Print stats
        print(f"Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy
        print("  Per-class validation accuracy:")
        for i in valid_classes:
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save stats
        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "accuracy": epoch_val_acc
        }

        # Use the same column name whether or not the class occurred, so the
        # metrics CSV has one consistent column per class across epochs.
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                val_stat_entry[f"{class_name}_acc"] = class_correct[i] / class_total[i]
            else:
                val_stat_entry[f"{class_name}_acc"] = 0.0

        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first evaluated
        # epoch always counts as an improvement so a checkpoint always exists.
        if not checkpoint_saved or epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(model.state_dict(), best_checkpoint_path)
            checkpoint_saved = True
            print(f"  New best teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only a checkpoint written by this run)
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))
        print(f"Loaded best teacher model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No teacher checkpoint was written (no epochs were run); keeping current weights.")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/vanilla_teacher_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/vanilla_teacher_val_metrics.csv", index=False)

    return model, best_val_acc


def train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=6,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=0.5,
        temperature=3.0,
        num_classes=5
):
    """Train vanilla RoBERTa student model with knowledge distillation and early stopping.

    The training objective is ``(1 - alpha) * CE + alpha * KL`` where CE is the
    cross-entropy against the hard labels and KL is the KL divergence between
    the temperature-softened student and teacher distributions, scaled by
    ``temperature ** 2``. The teacher is kept in eval mode and is only run
    under ``torch.no_grad()``.

    Optimisation uses AdamW with a linear-warmup / cosine-decay schedule over
    ``len(train_loader) * epochs`` steps and gradient-norm clipping at 1.0.
    The student weights with the best validation accuracy are written to
    ``results/models/vanilla_student_best.pt`` and restored before returning.
    Training stops early after 2 consecutive epochs without improvement.
    Per-epoch metrics are written to CSV files under ``results/metrics``.

    Args:
        student_model: Model being trained; its forward pass must return an
            object with a ``logits`` attribute.
        teacher_model: Model providing the soft targets (same forward contract).
        train_loader: Sized iterable of batches, each a mapping with
            "input_ids", "attention_mask" and "coarse_label" tensors.
        val_loader: Iterable of validation batches with the same keys.
        epochs: Maximum number of training epochs.
        learning_rate: AdamW learning rate.
        weight_decay: Weight decay for all parameters except biases and
            LayerNorm weights.
        alpha: Weight of the distillation (KL) term in the combined loss.
        temperature: Softmax temperature used for the distillation term.
        num_classes: Number of class ids (``0 .. num_classes - 1``) for which
            per-class validation accuracy is reported.

    Returns:
        Tuple ``(student_model, best_val_acc)``.
    """
    print(f"Training vanilla student model with distillation for up to {epochs} epochs with early stopping...")

    best_checkpoint_path = "results/models/vanilla_student_best.pt"

    # Prepare optimizer
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Learning rate scheduler with warmup
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)  # 10% warmup

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    # Loss functions
    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS run wrote a checkpoint, so a stale file left over
    # from an earlier run is never loaded by mistake.
    checkpoint_saved = False
    train_stats = []
    val_stats = []

    # Define class names for better interpretability
    label_map = {0: "ABBR", 1: "DESC", 2: "ENTY", 3: "HUM", 4: "LOC", 5: "NUM"}

    # Define valid classes based on num_classes
    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # Training
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_ce_loss = 0
        train_kl_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Vanilla Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            # Get batch data
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Using coarse_label for TREC dataset
            labels = batch["coarse_label"].to(device)

            # Forward pass - student
            optimizer.zero_grad()
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Forward pass - teacher (no gradients needed for the targets)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            # Cross-entropy loss
            ce_loss = ce_loss_fn(student_logits, labels)
            train_ce_loss += ce_loss.item() * labels.size(0)

            # Distillation loss: KLDivLoss expects log-probabilities as input
            # and probabilities as target.
            student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
            teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
            kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)
            train_kl_loss += kl_loss.item() * labels.size(0)

            # Combined loss
            loss = (1 - alpha) * ce_loss + alpha * kl_loss
            train_loss += loss.item() * labels.size(0)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track accuracy
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            # Update progress bar
            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        # Calculate epoch stats (max(..., 1) guards against an empty loader)
        train_denominator = max(train_total, 1)
        epoch_train_loss = train_loss / train_denominator
        epoch_train_ce_loss = train_ce_loss / train_denominator
        epoch_train_kl_loss = train_kl_loss / train_denominator
        epoch_train_acc = train_correct / train_denominator

        # Validation
        student_model.eval()
        val_loss = 0
        val_ce_loss = 0
        val_kl_loss = 0
        val_correct = 0
        val_total = 0

        # Per-class metrics start from zero for every validation pass
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Using coarse_label for TREC dataset
                labels = batch["coarse_label"].to(device)

                # Student predictions
                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits

                # Teacher predictions
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

                # Cross-entropy loss
                ce_loss = ce_loss_fn(student_logits, labels)
                val_ce_loss += ce_loss.item() * labels.size(0)

                # Distillation loss
                student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
                teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
                kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)
                val_kl_loss += kl_loss.item() * labels.size(0)

                # Combined loss
                loss = (1 - alpha) * ce_loss + alpha * kl_loss
                val_loss += loss.item() * labels.size(0)

                # Accuracy
                preds = torch.argmax(student_logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        # Calculate validation stats
        val_denominator = max(val_total, 1)
        epoch_val_loss = val_loss / val_denominator
        epoch_val_ce_loss = val_ce_loss / val_denominator
        epoch_val_kl_loss = val_kl_loss / val_denominator
        epoch_val_acc = val_correct / val_denominator

        # Print stats
        print(f"Vanilla Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy
        print("  Per-class validation accuracy:")
        for i in valid_classes:
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save stats
        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "ce_loss": epoch_train_ce_loss,
            "kl_loss": epoch_train_kl_loss,
            "accuracy": epoch_train_acc
        })

        # Add per-class metrics to validation stats
        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "ce_loss": epoch_val_ce_loss,
            "kl_loss": epoch_val_kl_loss,
            "accuracy": epoch_val_acc
        }

        # Use the same column name whether or not the class occurred, so the
        # metrics CSV has one consistent column per class across epochs.
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                val_stat_entry[f"{class_name}_acc"] = class_correct[i] / class_total[i]
            else:
                val_stat_entry[f"{class_name}_acc"] = 0.0

        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first evaluated
        # epoch always counts as an improvement so a checkpoint always exists.
        if not checkpoint_saved or epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(student_model.state_dict(), best_checkpoint_path)
            checkpoint_saved = True
            print(f"  New best vanilla student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only a checkpoint written by this run)
    if checkpoint_saved:
        student_model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))
        print(f"Loaded best vanilla student model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No student checkpoint was written (no epochs were run); keeping current weights.")

    # Save final model too
    # NOTE(review): the best checkpoint has just been reloaded, so this file
    # holds the same weights as vanilla_student_best.pt rather than the
    # last-epoch weights - confirm which of the two was intended.
    torch.save(student_model.state_dict(), "results/models/vanilla_student_final.pt")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/vanilla_student_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/vanilla_student_val_metrics.csv", index=False)

    return student_model, best_val_acc


def main():
    """Main execution function."""
    print("\n=== Vanilla RoBERTa Student-Teacher Knowledge Distillation ===\n")

    # Parameters - adjusted for memory constraints and comparison
    NUM_TEACHER_LAYERS = 3  # Same as higher-order version
    NUM_STUDENT_LAYERS = 2  # Fewer layers than teacher for distillation
    BATCH_SIZE = 8  # Same as higher-order version
    MAX_LENGTH = 128  # Same as higher-order version
    MAX_SAMPLES = 3000  # Same as higher-order version
    TEACHER_EPOCHS = 5  # Same as higher-order version
    STUDENT_EPOCHS = 5  # Same as higher-order version
    DISTILLATION_ALPHA = 0.5  # Same as higher-order version
    TEMPERATURE = 3.0  # Same as higher-order version
    NUM_CLASSES = 5  # Using 5 classes to match higher-order version: ABBR, DESC, ENTY, HUM, LOC

    # Save configuration
    config_dict = {
        "model_type": "vanilla_roberta",
        "model_source": "roberta-base (pretrained)",
        "num_teacher_layers": NUM_TEACHER_LAYERS,
        "num_student_layers": NUM_STUDENT_LAYERS,
        "batch_size": BATCH_SIZE,
        "max_length": MAX_LENGTH,
        "max_samples": MAX_SAMPLES,
        "teacher_epochs": TEACHER_EPOCHS,
        "student_epochs": STUDENT_EPOCHS,
        "distillation_alpha": DISTILLATION_ALPHA,
        "temperature": TEMPERATURE,
        "attention_type": "vanilla",
        "dropout": 0.3,
        "weight_decay": 0.03,
        "learning_rate": 2e-5,
        "dataset": "TREC (5-class: ABBR, DESC, ENTY, HUM, LOC)",
        "num_classes": NUM_CLASSES
    }

    # Save config
    with open("results/metrics/vanilla_roberta_config.json", "w") as f:
        import json
        json.dump(config_dict, f, indent=2)

    # 1. Load TREC data restricted to NUM_CLASSES classes (with reduced sample size for memory constraints)
    tokenizer, train_loader, val_loader = load_and_prepare_trec_data(
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        max_samples=MAX_SAMPLES,
        num_classes=NUM_CLASSES
    )

    # 2. Load pretrained RoBERTa model and adapt it to our size and number of classes
    teacher_model, roberta_config = load_pretrained_teacher_model(
        num_layers=NUM_TEACHER_LAYERS,
        num_classes=NUM_CLASSES
    )
    teacher_model.to(device)

    # 3. Finetune teacher model
    print("\n=== Step 1: Finetuning Vanilla RoBERTa Teacher Model ===\n")
    teacher_model, teacher_acc = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=TEACHER_EPOCHS,
        num_classes=NUM_CLASSES
    )

    # 4. Create and train vanilla student model with standard attention
    print("\n=== Step 2: Training Vanilla RoBERTa Student Model with Distillation ===\n")

    # Create vanilla student model with fewer layers than teacher
    student_model = create_vanilla_student_model(
        roberta_config,
        teacher_model,
        student_layers=NUM_STUDENT_LAYERS
    )
    student_model.to(device)

    # Train with distillation
    student_model, best_val_acc = train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=STUDENT_EPOCHS,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=DISTILLATION_ALPHA,
        temperature=TEMPERATURE,
        num_classes=NUM_CLASSES
    )

    print("\n=== Training Complete ===")
    print(f"Best vanilla student model validation accuracy: {best_val_acc:.4f}")
    print(f"Models saved to: results/models/")
    print(f"Training metrics saved to: results/metrics/")

    # Additional analysis - compare teacher vs student on validation set
    print("\n=== Comparing Vanilla Teacher vs Student Performance ===")
    teacher_model.eval()
    student_model.eval()

    teacher_correct = 0
    student_correct = 0
    total = 0

    # Define valid classes
    valid_classes = list(range(NUM_CLASSES))

    # For per-class metrics
    teacher_class_correct = [0] * NUM_CLASSES
    student_class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES

    # Define class names for better interpretability
    label_map = {0: "ABBR", 1: "DESC", 2: "ENTY", 3: "HUM", 4: "LOC", 5: "NUM"}

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating Models"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Using coarse_label for TREC dataset
            labels = batch["coarse_label"].to(device)

            # Teacher predictions
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_preds = torch.argmax(teacher_outputs.logits, dim=-1)
            teacher_correct += (teacher_preds == labels).sum().item()

            # Student predictions
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_preds = torch.argmax(student_outputs.logits, dim=-1)
            student_correct += (student_preds == labels).sum().item()

            # Calculate per-class accuracy
            for i in valid_classes:
                class_mask = (labels == i)
                class_count = class_mask.sum().item()

                if class_count > 0:
                    teacher_class_correct[i] += ((teacher_preds == labels) & class_mask).sum().item()
                    student_class_correct[i] += ((student_preds == labels) & class_mask).sum().item()
                    class_total[i] += class_count

            total += labels.size(0)

    teacher_acc = teacher_correct / total
    student_acc = student_correct / total

    print(f"Final Vanilla RoBERTa Teacher Accuracy: {teacher_acc:.4f}")
    print(f"Final Vanilla RoBERTa Student Accuracy: {student_acc:.4f}")
    print(f"Difference: {(student_acc - teacher_acc):.4f}")

    # Print per-class accuracies
    print("\nPer-Class Accuracies:")
    for i in valid_classes:
        if class_total[i] > 0:
            teacher_class_acc = teacher_class_correct[i] / class_total[i]
            student_class_acc = student_class_correct[i] / class_total[i]
            class_name = label_map.get(i, f"Class {i}")
            print(f"  {class_name}: Teacher: {teacher_class_acc:.4f}, Student: {student_class_acc:.4f}, "
                  f"Diff: {(student_class_acc - teacher_class_acc):.4f}")

    # Save comparison
    with open("results/metrics/vanilla_roberta_comparison.txt", "w") as f:
        f.write("Vanilla RoBERTa Teacher vs Student Model Comparison\n")
        f.write("=================================================\n\n")
        f.write(f"Dataset: TREC ({NUM_CLASSES}-class: {', '.join([label_map[i] for i in valid_classes])})\n")
        f.write(f"Teacher Model Accuracy: {teacher_acc:.4f}\n")
        f.write(f"Student Model Accuracy: {student_acc:.4f}\n")
        f.write(f"Difference: {(student_acc - teacher_acc):.4f}\n\n")

        f.write("Per-Class Accuracies:\n")
        for i in valid_classes:
            if class_total[i] > 0:
                teacher_class_acc = teacher_class_correct[i] / class_total[i]
                student_class_acc = student_class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                f.write(f"  {class_name}: Teacher: {teacher_class_acc:.4f}, Student: {student_class_acc:.4f}, "
                      f"Diff: {(student_class_acc - teacher_class_acc):.4f}\n")

        f.write("\nTraining Configuration:\n")
        for key, value in config_dict.items():
            f.write(f"  {key}: {value}\n")

    # Create a confusion matrix for both models
    print("\n=== Generating Confusion Matrices ===")

    # Initialize confusion matrices
    teacher_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    student_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Computing Confusion Matrices"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Using coarse_label for TREC dataset
            labels = batch["coarse_label"].to(device)

            # Teacher predictions
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_preds = torch.argmax(teacher_outputs.logits, dim=-1)

            # Student predictions
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_preds = torch.argmax(student_outputs.logits, dim=-1)

            # Update confusion matrices
            for true, pred in zip(labels.cpu().numpy(), teacher_preds.cpu().numpy()):
                teacher_confusion[true, pred] += 1

            for true, pred in zip(labels.cpu().numpy(), student_preds.cpu().numpy()):
                student_confusion[true, pred] += 1

    # Print confusion matrices
    print("\nVanilla Teacher Confusion Matrix:")
    print(teacher_confusion)
    print("\nVanilla Student Confusion Matrix:")
    print(student_confusion)

    # Save confusion matrices
    np.savetxt("results/metrics/vanilla_teacher_confusion_matrix.csv", teacher_confusion, delimiter=',', fmt='%d')
    np.savetxt("results/metrics/vanilla_student_confusion_matrix.csv", student_confusion, delimiter=',', fmt='%d')

    # Also save confusion matrix with human-readable class names
    with open("results/metrics/vanilla_confusion_matrix_class_names.txt", "w") as f:
        f.write("Class indices to names mapping:\n")
        for i in valid_classes:
            f.write(f"{i}: {label_map[i]}\n")

    # Model size comparison
    print("\n=== Model Size Comparison ===")

    def count_parameters(model):
        """Return the total number of elements across the model's trainable parameters.

        Only parameters whose ``requires_grad`` flag is truthy are counted;
        the rest are skipped.
        """
        trainable_total = 0
        for param in model.parameters():
            if param.requires_grad:
                trainable_total += param.numel()
        return trainable_total

    teacher_params = count_parameters(teacher_model)
    student_params = count_parameters(student_model)

    print(f"Teacher model parameters: {teacher_params:,}")
    print(f"Student model parameters: {student_params:,}")
    print(f"Parameter reduction: {((teacher_params - student_params) / teacher_params * 100):.2f}%")

    # Save model comparison stats
    model_stats = {
        "teacher_params": teacher_params,
        "student_params": student_params,
        "param_reduction_pct": (teacher_params - student_params) / teacher_params * 100,
        "teacher_layers": NUM_TEACHER_LAYERS,
        "student_layers": NUM_STUDENT_LAYERS,
        "teacher_accuracy": teacher_acc,
        "student_accuracy": student_acc,
        "accuracy_difference": student_acc - teacher_acc
    }

    with open("results/metrics/vanilla_model_stats.json", "w") as f:
        import json
        json.dump(model_stats, f, indent=2)

    print("\nAll vanilla results saved successfully!")
    print("\n=== Summary ===")
    print(f"Vanilla Teacher ({NUM_TEACHER_LAYERS} layers): {teacher_acc:.4f} accuracy")
    print(f"Vanilla Student ({NUM_STUDENT_LAYERS} layers): {student_acc:.4f} accuracy")
    print(f"Student achieved {((student_acc / teacher_acc) * 100):.2f}% of teacher performance")
    print(f"With {((teacher_params - student_params) / teacher_params * 100):.2f}% fewer parameters")


# Script entry point: invoke main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()