# -*- coding: utf-8 -*-
"""Untitled49.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/126NKQmbjdlw4CatIcUSXcZbDW24p_vjq
"""

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Vanilla 2D RoBERTa Student-Teacher Knowledge Distillation
Standard attention mechanisms for comparison with higher-order 3D attention.

"""

import copy
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, concatenate_datasets
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

# Select the compute device: CUDA when it is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create the output directories used below for model checkpoints and metric files
os.makedirs("results", exist_ok=True)
os.makedirs("results/models", exist_ok=True)
os.makedirs("results/metrics", exist_ok=True)

# Seed the torch and NumPy random number generators for reproducibility
torch.manual_seed(42)
np.random.seed(42)


class VanillaRobertaStudentModel(RobertaForSequenceClassification):
    """
    Vanilla RoBERTa student model with standard 2D attention mechanisms.

    Structurally this is a plain ``RobertaForSequenceClassification`` built
    from ``config``. When a teacher is supplied, the student starts from
    independent copies of the teacher's embeddings, pooler, classifier and
    first ``config.num_hidden_layers`` encoder layers.

    Args:
        config: Configuration used to construct the student.
        teacher_model: Optional model exposing ``roberta.embeddings``,
            ``roberta.pooler``, ``roberta.encoder.layer`` and ``classifier``.
            When ``None`` the student keeps its freshly initialised weights.
    """

    def __init__(self, config, teacher_model=None):
        # The parent constructor builds every submodule and finalises
        # initialisation, so nothing further is needed without a teacher.
        super().__init__(config)

        if teacher_model is None:
            return

        teacher_backbone = teacher_model.roberta

        # Deep-copy rather than assign: assigning the teacher's modules would
        # make both models share the same parameters, so training the student
        # would also change the teacher, and switching the teacher to eval
        # mode would switch the shared student modules too.
        self.roberta.embeddings = copy.deepcopy(teacher_backbone.embeddings)
        self.roberta.pooler = copy.deepcopy(teacher_backbone.pooler)

        # Take the first layers of the teacher; a teacher with fewer layers
        # than the student leaves the remaining student layers untouched.
        teacher_layers = teacher_backbone.encoder.layer
        shared_depth = min(config.num_hidden_layers, len(teacher_layers))
        for index in range(shared_depth):
            self.roberta.encoder.layer[index] = copy.deepcopy(teacher_layers[index])

        self.classifier = copy.deepcopy(teacher_model.classifier)

        # NOTE: post_init() is deliberately not called again here. Depending
        # on the transformers version it re-applies weight initialisation to
        # modules not marked as initialised, which could overwrite the
        # teacher weights copied above.


def load_pretrained_teacher_model(num_layers=3, num_classes=4):
    """Build a truncated RoBERTa classifier initialised from ``roberta-base``.

    A new ``RobertaForSequenceClassification`` with ``num_layers`` encoder
    layers, dropout 0.3 and ``num_classes`` labels is created, and the
    pretrained embeddings and first ``num_layers`` encoder layers are loaded
    into it. The classification head is the one created by the model
    constructor and is left untrained.

    Args:
        num_layers: Number of encoder layers to keep from the pretrained model.
        num_classes: Number of output classes of the classification head.

    Returns:
        Tuple ``(model, config)`` with the truncated model and its config.
    """
    print("Loading pretrained RoBERTa model...")

    # Full pretrained model, used only as the source of weights
    pretrained = RobertaForSequenceClassification.from_pretrained("roberta-base")

    # Create new config with fewer layers but same dimensions
    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = num_layers
    config.hidden_dropout_prob = 0.3
    config.attention_probs_dropout_prob = 0.3
    config.num_labels = num_classes

    # Create new model with fewer layers
    new_model = RobertaForSequenceClassification(config)

    # Load the pretrained weights into the new model's own modules instead of
    # swapping in the pretrained module objects: the swapped-in modules would
    # keep the dropout layers they were built with, so the dropout configured
    # above would not take effect in the embeddings and encoder.
    new_model.roberta.embeddings.load_state_dict(pretrained.roberta.embeddings.state_dict())
    for i in range(num_layers):
        new_model.roberta.encoder.layer[i].load_state_dict(
            pretrained.roberta.encoder.layer[i].state_dict()
        )

    # Carry the pooler over only when both models actually have one
    source_pooler = pretrained.roberta.pooler
    target_pooler = new_model.roberta.pooler
    if source_pooler is not None and target_pooler is not None:
        target_pooler.load_state_dict(source_pooler.state_dict())

    # The classifier built by the constructor is already sized from
    # config.num_labels, so it is kept as is rather than replaced.
    print(f"Verifying model output size: {new_model.classifier.out_proj.out_features} classes")
    assert new_model.classifier.out_proj.out_features == num_classes, "Model output size doesn't match num_classes"

    print(f"Created teacher model with {num_layers} layers from pretrained RoBERTa with {num_classes} output classes")
    return new_model, config


def create_vanilla_student_model(config, teacher_model, student_layers=2):
    """Build the vanilla 2D student from a teacher model.

    The student configuration starts from the ``roberta-base`` config and is
    adjusted to ``student_layers`` encoder layers, dropout 0.3 and the same
    number of labels as ``config``.

    Args:
        config: Teacher configuration; only ``num_labels`` is read from it.
        teacher_model: Teacher handed to ``VanillaRobertaStudentModel``.
        student_layers: Number of encoder layers in the student.

    Returns:
        The constructed ``VanillaRobertaStudentModel``.
    """
    print(f"Creating vanilla 2D student model with {student_layers} layers from teacher...")

    student_config = RobertaConfig.from_pretrained("roberta-base")

    # Apply every override in one place
    overrides = {
        "num_hidden_layers": student_layers,
        "hidden_dropout_prob": 0.3,
        "attention_probs_dropout_prob": 0.3,
        "num_labels": config.num_labels,
    }
    for attribute, value in overrides.items():
        setattr(student_config, attribute, value)

    student = VanillaRobertaStudentModel(student_config, teacher_model)

    print(f"Successfully created vanilla 2D student model with {student_layers} layers")
    return student


# Leading question words per heuristic class; matched as whole words only.
_FACTOID_STARTERS = ("who", "what", "when", "where")
_EXPLANATION_STARTERS = ("why", "how")
_YES_NO_STARTERS = (
    "is", "are", "was", "were", "do", "does", "did", "can", "cannot", "could",
    "will", "would", "has", "have", "had", "should", "shall",
)
# Keywords searched anywhere in the question (plain substring match).
_COMPARISON_KEYWORDS = ("which", "better", "difference", "compare", "versus", "vs")


def _starts_with_word(text, words):
    """Return True if *text* begins with one of *words* as a whole word.

    A trailing "n't" contraction is accepted ("isn't" matches "is"), but a
    longer word sharing the prefix is not ("island" does not match "is").
    """
    for word in words:
        if not text.startswith(word):
            continue
        remainder = text[len(word):]
        if remainder.startswith("n't"):
            remainder = remainder[3:]
        # An empty remainder or a non-letter marks the end of the word
        if not remainder[:1].isalpha():
            return True
    return False


def _question_class(question):
    """Map a question string to its heuristic class id (0-3)."""
    # Normalise case, surrounding whitespace and typographic apostrophes
    text = question.strip().lower().replace("\u2019", "'")

    if _starts_with_word(text, _FACTOID_STARTERS):
        return 0  # Factoid
    if _starts_with_word(text, _EXPLANATION_STARTERS):
        return 1  # Explanation
    if _starts_with_word(text, _YES_NO_STARTERS):
        return 2  # Yes/No
    if any(keyword in text for keyword in _COMPARISON_KEYWORDS):
        return 3  # Comparison
    return 0  # Default to factoid when no pattern matches


def _balanced_subset(data, samples_per_class, num_classes):
    """Return up to *samples_per_class* shuffled examples of each class, concatenated."""
    per_class = []
    for class_id in range(num_classes):
        # The filter runs immediately, so it sees the current class_id
        members = data.filter(lambda example: example["qa_class"] == class_id)
        members = members.shuffle(seed=42).select(range(min(samples_per_class, len(members))))
        per_class.append(members)
    return concatenate_datasets(per_class)


def _print_class_distribution(title, labels, num_classes, label_map):
    """Print the count and percentage of each class id in *labels*."""
    print(title)
    total = len(labels)
    for class_id in range(num_classes):
        count = labels.count(class_id)
        # Guard against an empty split instead of dividing by zero
        percentage = count / total * 100 if total else 0.0
        class_name = label_map.get(class_id, f"Class {class_id}")
        print(f"  {class_name}: {count} samples ({percentage:.2f}%)")


def load_and_prepare_qa_classification_data(batch_size=16, max_length=128, max_samples=5000, num_classes=4):
    """Load SQuAD and turn it into a question-type classification task.

    Each question is labelled with a heuristic class stored in ``qa_class``:
    0 factoid (who/what/when/where), 1 explanation (why/how), 2 yes/no
    (auxiliary-verb openers), 3 comparison (keyword anywhere in the question);
    anything else defaults to 0. Only the question text is tokenized.

    Args:
        batch_size: Batch size of both returned dataloaders.
        max_length: Length questions are padded/truncated to.
        max_samples: If set and smaller than the training set, the training
            set is reduced to ``max_samples // num_classes`` examples per
            class and the validation set to half of that per class.
        num_classes: Examples with ``qa_class >= num_classes`` are dropped.

    Returns:
        Tuple ``(tokenizer, train_loader, val_loader, label_map)``.
    """
    print(f"Loading and preparing Question Classification dataset (limited to {max_samples} samples, {num_classes} classes)...")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # SQuAD is used as the base and converted into a classification task
    # based on question type
    dataset = load_dataset("squad")

    def classify_question(example):
        example["qa_class"] = _question_class(example["question"])
        return example

    # Apply classification to the dataset
    train_data = dataset["train"].map(classify_question)
    val_data = dataset["validation"].map(classify_question)

    # Keep only the requested classes
    train_data = train_data.filter(lambda example: example["qa_class"] < num_classes)
    val_data = val_data.filter(lambda example: example["qa_class"] < num_classes)

    # Limit dataset size for memory constraints
    if max_samples and max_samples < len(train_data):
        train_subset = _balanced_subset(train_data, max_samples // num_classes, num_classes)
        val_subset = _balanced_subset(val_data, (max_samples // 2) // num_classes, num_classes)
        print(f"Limited dataset to {len(train_subset)} train and {len(val_subset)} validation samples")
    else:
        train_subset = train_data
        val_subset = val_data

    def tokenize_function(examples):
        # For QA classification, only the question text is used
        return tokenizer(
            examples["question"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    # Tokenize data
    tokenized_train = train_subset.map(tokenize_function, batched=True)
    tokenized_val = val_subset.map(tokenize_function, batched=True)

    # Format for PyTorch - use the qa_class field as the label
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])
    tokenized_val.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])

    # Create dataloaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    print(f"Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")

    # Collect labels for the class-distribution report
    train_labels = [example["qa_class"].item() for example in tokenized_train]
    val_labels = [example["qa_class"].item() for example in tokenized_val]

    # Map numeric labels to class names for better understanding
    label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    _print_class_distribution("Train class distribution:", train_labels, num_classes, label_map)
    _print_class_distribution("Validation class distribution:", val_labels, num_classes, label_map)

    return tokenizer, train_loader, val_loader, label_map


def finetune_teacher_model(model, train_loader, val_loader, epochs=5, num_classes=4, label_map=None):
    """Finetune the teacher model for QA classification with early stopping.

    Trains with AdamW (lr 2e-5, weight decay 0.03 except on biases and
    LayerNorm weights), a linear-warmup/cosine-decay schedule, gradient
    clipping at norm 1.0 and cross-entropy loss. After every epoch the model
    is validated; the best checkpoint by validation accuracy is written to
    ``results/models`` and training stops after 2 epochs without improvement.
    Per-epoch metrics are written to CSV files in ``results/metrics``.

    Args:
        model: Model whose output exposes ``.logits``; trained in place.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``qa_class``.
        val_loader: Validation batches with the same keys.
        epochs: Maximum number of epochs.
        num_classes: Number of classes tracked in the per-class accuracy.
        label_map: Optional ``{class_id: name}`` mapping used for reporting.

    Returns:
        Tuple ``(model, best_val_acc)``; the model holds the best checkpoint's
        weights when at least one epoch ran.
    """
    print(f"Finetuning vanilla 2D teacher model for up to {epochs} epochs with early stopping...")

    checkpoint_path = "results/models/vanilla_2d_qa_teacher_best.pt"

    # Biases and LayerNorm weights are excluded from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.03,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

    # Learning rate scheduler: 10% linear warmup, then cosine decay
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    loss_fn = nn.CrossEntropyLoss()

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    checkpoint_saved = False  # True once this run has written a checkpoint
    train_stats = []
    val_stats = []

    # Use the provided label_map or default to QA classification labels
    if label_map is None:
        label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Vanilla 2D Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            optimizer.zero_grad()
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = loss_fn(logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Loss is weighted by batch size so the epoch value is a per-sample mean
            batch_loss = loss.item()
            batch_count = labels.size(0)
            train_loss += batch_loss * batch_count
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += batch_count

            progress_bar.set_postfix({
                "loss": batch_loss,
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        epoch_train_loss = train_loss / train_total
        epoch_train_acc = train_correct / train_total

        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                hits = preds == labels
                val_correct += hits.sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += (hits & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / val_total
        epoch_val_acc = val_correct / val_total

        # Print stats
        print(f"Vanilla 2D Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        print("  Per-class validation accuracy:")
        for i in valid_classes:
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save stats
        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "accuracy": epoch_val_acc
        }
        # One column per class, named the same way whether or not the class
        # has validation samples (0.0 is recorded when it has none)
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                val_stat_entry[f"{class_name}_acc"] = class_correct[i] / class_total[i]
            else:
                val_stat_entry[f"{class_name}_acc"] = 0.0
        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first epoch is
        # always saved so that the checkpoint reloaded below belongs to this
        # run even if validation accuracy never rises above 0.
        if epoch_val_acc > best_val_acc or not checkpoint_saved:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best vanilla 2D teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only if this run produced a checkpoint)
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
        print(f"Loaded best vanilla 2D teacher model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No training epochs were run; returning the teacher model unchanged.")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/vanilla_2d_qa_teacher_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/vanilla_2d_qa_teacher_val_metrics.csv", index=False)

    return model, best_val_acc


def train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=5,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=0.5,
        temperature=3.0,
        num_classes=4,
        label_map=None
):
    """Train the student with knowledge distillation and early stopping.

    The loss is ``(1 - alpha) * CE(student, labels) + alpha * KL``, where the
    KL term compares temperature-softened student and teacher distributions
    and is scaled by ``temperature ** 2``. Optimisation uses AdamW with a
    linear-warmup/cosine-decay schedule and gradient clipping at norm 1.0.
    The best checkpoint by validation accuracy is written to
    ``results/models``; training stops after 2 epochs without improvement.
    Per-epoch metrics are written to CSV files in ``results/metrics``.

    Args:
        student_model: Model being trained; its output exposes ``.logits``.
        teacher_model: Model providing soft targets; only run under ``no_grad``.
        train_loader: Batches with ``input_ids``, ``attention_mask``, ``qa_class``.
        val_loader: Validation batches with the same keys.
        epochs: Maximum number of epochs.
        learning_rate: AdamW learning rate.
        weight_decay: Weight decay for all parameters except biases and
            LayerNorm weights.
        alpha: Weight of the distillation (KL) term in the combined loss.
        temperature: Softmax temperature for the distillation term.
        num_classes: Number of classes tracked in the per-class accuracy.
        label_map: Optional ``{class_id: name}`` mapping used for reporting.

    Returns:
        Tuple ``(student_model, best_val_acc)``; the student holds the best
        checkpoint's weights when at least one epoch ran.
    """
    print(f"Training vanilla 2D student model with distillation for up to {epochs} epochs with early stopping...")

    checkpoint_path = "results/models/vanilla_2d_qa_student_best.pt"

    # Biases and LayerNorm weights are excluded from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Learning rate scheduler: 10% linear warmup, then cosine decay
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    # Loss functions
    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    def distillation_losses(student_logits, teacher_logits, labels):
        """Return (combined, cross-entropy, KL) losses for one batch."""
        ce = ce_loss_fn(student_logits, labels)
        soft_student = F.log_softmax(student_logits / temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
        kl = kl_loss_fn(soft_student, soft_teacher) * (temperature ** 2)
        return (1 - alpha) * ce + alpha * kl, ce, kl

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    checkpoint_saved = False  # True once this run has written a checkpoint
    train_stats = []
    val_stats = []

    # Define class names for better interpretability
    if label_map is None:
        label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # ---- Training ----
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_ce_loss = 0
        train_kl_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Vanilla 2D Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            # Forward pass - student
            optimizer.zero_grad()
            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

            # Forward pass - teacher (no gradients needed for soft targets)
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

            loss, ce_loss, kl_loss = distillation_losses(student_logits, teacher_logits, labels)

            # Losses are weighted by batch size so epoch values are per-sample means
            batch_count = labels.size(0)
            batch_loss = loss.item()
            train_ce_loss += ce_loss.item() * batch_count
            train_kl_loss += kl_loss.item() * batch_count
            train_loss += batch_loss * batch_count

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track accuracy
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += batch_count

            progress_bar.set_postfix({
                "loss": batch_loss,
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        epoch_train_loss = train_loss / train_total
        epoch_train_ce_loss = train_ce_loss / train_total
        epoch_train_kl_loss = train_kl_loss / train_total
        epoch_train_acc = train_correct / train_total

        # ---- Validation ----
        student_model.eval()
        val_loss = 0
        val_ce_loss = 0
        val_kl_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

                loss, ce_loss, kl_loss = distillation_losses(student_logits, teacher_logits, labels)

                batch_count = labels.size(0)
                val_ce_loss += ce_loss.item() * batch_count
                val_kl_loss += kl_loss.item() * batch_count
                val_loss += loss.item() * batch_count

                # Accuracy
                preds = torch.argmax(student_logits, dim=-1)
                hits = preds == labels
                val_correct += hits.sum().item()
                val_total += batch_count

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += (hits & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / val_total
        epoch_val_ce_loss = val_ce_loss / val_total
        epoch_val_kl_loss = val_kl_loss / val_total
        epoch_val_acc = val_correct / val_total

        # Print stats
        print(f"Vanilla 2D Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        print("  Per-class validation accuracy:")
        for i in valid_classes:
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save stats
        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "ce_loss": epoch_train_ce_loss,
            "kl_loss": epoch_train_kl_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "ce_loss": epoch_val_ce_loss,
            "kl_loss": epoch_val_kl_loss,
            "accuracy": epoch_val_acc
        }
        # One column per class, named the same way whether or not the class
        # has validation samples (0.0 is recorded when it has none)
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                val_stat_entry[f"{class_name}_acc"] = class_correct[i] / class_total[i]
            else:
                val_stat_entry[f"{class_name}_acc"] = 0.0
        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first epoch is
        # always saved so that the checkpoint reloaded below belongs to this
        # run even if validation accuracy never rises above 0.
        if epoch_val_acc > best_val_acc or not checkpoint_saved:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(student_model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best vanilla 2D student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only if this run produced a checkpoint)
    if checkpoint_saved:
        student_model.load_state_dict(torch.load(checkpoint_path))
        print(f"Loaded best vanilla 2D student model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No training epochs were run; returning the student model unchanged.")

    # Save the returned model under the "final" name too. NOTE: when a best
    # checkpoint was reloaded above, this file holds those same weights, not
    # the last-epoch weights.
    torch.save(student_model.state_dict(), "results/models/vanilla_2d_qa_student_final.pt")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/vanilla_2d_qa_student_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/vanilla_2d_qa_student_val_metrics.csv", index=False)

    return student_model, best_val_acc


def _collect_validation_predictions(teacher_model, student_model, val_loader):
    """Run teacher and student once over ``val_loader`` and gather the results.

    Both models are put in eval mode and executed under ``torch.no_grad()``.
    Every batch must be a mapping holding ``"input_ids"``, ``"attention_mask"``
    and ``"qa_class"`` tensors, and each model's output must expose ``.logits``.

    Args:
        teacher_model: Fine-tuned teacher classifier.
        student_model: Distilled student classifier.
        val_loader: Iterable of validation batches.

    Returns:
        Tuple ``(labels, teacher_preds, student_preds, seq_lengths)`` of NumPy
        arrays with one entry per validation example. ``seq_lengths`` is the
        per-example sum of the attention mask. All four arrays are empty when
        the loader yields no batches.
    """
    teacher_model.eval()
    student_model.eval()

    labels_chunks = []
    teacher_chunks = []
    student_chunks = []
    length_chunks = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating Models"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

            labels_chunks.append(batch["qa_class"].cpu().numpy())
            teacher_chunks.append(torch.argmax(teacher_logits, dim=-1).cpu().numpy())
            student_chunks.append(torch.argmax(student_logits, dim=-1).cpu().numpy())
            # Number of attended (non-masked) positions per example.
            length_chunks.append(attention_mask.sum(dim=1).cpu().numpy())

    if not labels_chunks:
        empty = np.zeros(0, dtype=np.int64)
        return empty, empty.copy(), empty.copy(), empty.copy()

    return (
        np.concatenate(labels_chunks),
        np.concatenate(teacher_chunks),
        np.concatenate(student_chunks),
        np.concatenate(length_chunks),
    )


def main():
    """Run the vanilla 2D teacher/student distillation pipeline for QA classification.

    Steps, in order:
      1. Write the run configuration to ``results/metrics``.
      2. Load the QA classification data and the pretrained teacher.
      3. Fine-tune the teacher, then build and distil the student.
      4. Evaluate both models once on the validation set and derive overall
         accuracy, per-class accuracy, confusion matrices and accuracy by
         question length from those predictions.
      5. Write the comparison, confusion matrices, model statistics and the
         length/class analysis to ``results/metrics``.

    The hyperparameters are meant to mirror a companion higher-order 3D
    attention run, so that this script can serve as the 2D baseline.
    """
    # json is only needed here; keep the import local to this function.
    import json

    print("\n=== Vanilla 2D RoBERTa Student-Teacher Knowledge Distillation for QA Classification ===\n")

    # Parameters - matching the 3D higher-order version exactly
    NUM_LAYERS = 3
    BATCH_SIZE = 8
    MAX_LENGTH = 128
    MAX_SAMPLES = 3000
    TEACHER_EPOCHS = 5
    STUDENT_EPOCHS = 5
    DISTILLATION_ALPHA = 0.5
    TEMPERATURE = 3.0
    NUM_CLASSES = 4  # QA classification with 4 classes
    STUDENT_LAYERS = 2  # Fewer layers than teacher for distillation
    # Passed to the student distillation call below; the teacher fine-tuning
    # call does not receive them and uses its own defaults.
    LEARNING_RATE = 2e-5
    WEIGHT_DECAY = 0.03

    config_dict = {
        "model_type": "vanilla_2d_roberta",
        "model_source": "roberta-base (pretrained)",
        "num_layers": NUM_LAYERS,
        "student_layers": STUDENT_LAYERS,
        "batch_size": BATCH_SIZE,
        "max_length": MAX_LENGTH,
        "max_samples": MAX_SAMPLES,
        "teacher_epochs": TEACHER_EPOCHS,
        "student_epochs": STUDENT_EPOCHS,
        "distillation_alpha": DISTILLATION_ALPHA,
        "temperature": TEMPERATURE,
        "attention_type": "vanilla_2d",
        # NOTE(review): recorded for the report only; nothing in this function
        # applies it - confirm it matches the dropout used by the model builders.
        "dropout": 0.3,
        "weight_decay": WEIGHT_DECAY,
        "learning_rate": LEARNING_RATE,
        "dataset": "QA Classification (4-class: Factoid, Explanation, Yes/No, Comparison)",
        "num_classes": NUM_CLASSES,
    }

    with open("results/metrics/vanilla_2d_roberta_qa_config.json", "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2)

    # 1. Load QA Classification data
    tokenizer, train_loader, val_loader, label_map = load_and_prepare_qa_classification_data(
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        max_samples=MAX_SAMPLES,
        num_classes=NUM_CLASSES
    )

    # 2. Load pretrained RoBERTa model and adapt it
    teacher_model, roberta_config = load_pretrained_teacher_model(
        num_layers=NUM_LAYERS,
        num_classes=NUM_CLASSES
    )
    teacher_model.to(device)

    # 3. Finetune teacher model. The accuracy it returns is not used: the
    # teacher is re-evaluated next to the student further down.
    print("\n=== Step 1: Finetuning Vanilla 2D RoBERTa Teacher Model ===\n")
    teacher_model, _ = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=TEACHER_EPOCHS,
        num_classes=NUM_CLASSES,
        label_map=label_map
    )

    # 4. Create and train vanilla 2D student model with standard attention
    print("\n=== Step 2: Training Vanilla 2D RoBERTa Student Model with Distillation ===\n")

    student_model = create_vanilla_student_model(
        roberta_config,
        teacher_model,
        student_layers=STUDENT_LAYERS
    )
    student_model.to(device)

    student_model, best_val_acc = train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=STUDENT_EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        alpha=DISTILLATION_ALPHA,
        temperature=TEMPERATURE,
        num_classes=NUM_CLASSES,
        label_map=label_map
    )

    print("\n=== Training Complete ===")
    print(f"Best vanilla 2D student model validation accuracy: {best_val_acc:.4f}")
    print("Models saved to: results/models/")
    print("Training metrics saved to: results/metrics/")

    # ------------------------------------------------------------------
    # Single evaluation pass; every metric below is derived from it.
    # ------------------------------------------------------------------
    print("\n=== Comparing Vanilla 2D Teacher vs Student Performance ===")
    labels, teacher_preds, student_preds, seq_lengths = _collect_validation_predictions(
        teacher_model, student_model, val_loader
    )

    teacher_hits = teacher_preds == labels
    student_hits = student_preds == labels
    total = int(labels.shape[0])

    # Overall accuracies get their own names so that the per-bucket values
    # computed later cannot overwrite what is written to the stats file.
    overall_teacher_acc = int(teacher_hits.sum()) / total if total else 0.0
    overall_student_acc = int(student_hits.sum()) / total if total else 0.0

    valid_classes = list(range(NUM_CLASSES))
    # Fall back to a generic name when the label map lacks an index.
    class_names = [label_map.get(i, f"Class {i}") for i in valid_classes]

    print(f"Final Vanilla 2D RoBERTa Teacher Accuracy: {overall_teacher_acc:.4f}")
    print(f"Final Vanilla 2D RoBERTa Student Accuracy: {overall_student_acc:.4f}")
    print(f"Difference: {(overall_student_acc - overall_teacher_acc):.4f}")

    # Per-class accuracies; classes absent from the validation set are skipped.
    per_class_lines = []
    for i in valid_classes:
        class_mask = labels == i
        class_count = int(class_mask.sum())
        if class_count == 0:
            continue
        teacher_class_acc = int((teacher_hits & class_mask).sum()) / class_count
        student_class_acc = int((student_hits & class_mask).sum()) / class_count
        per_class_lines.append(
            f"  {class_names[i]}: Teacher: {teacher_class_acc:.4f}, Student: {student_class_acc:.4f}, "
            f"Diff: {(student_class_acc - teacher_class_acc):.4f}"
        )

    print("\nPer-Class Accuracies:")
    for line in per_class_lines:
        print(line)

    # Save comparison
    with open("results/metrics/vanilla_2d_roberta_qa_comparison.txt", "w", encoding="utf-8") as f:
        f.write("Vanilla 2D RoBERTa Teacher vs Student Model Comparison (QA 4-class)\n")
        f.write("================================================================\n\n")
        f.write(f"Dataset: QA Classification ({NUM_CLASSES}-class: {', '.join(class_names)})\n")
        f.write(f"Teacher Model Accuracy: {overall_teacher_acc:.4f}\n")
        f.write(f"Student Model Accuracy: {overall_student_acc:.4f}\n")
        f.write(f"Difference: {(overall_student_acc - overall_teacher_acc):.4f}\n\n")

        f.write("Per-Class Accuracies:\n")
        for line in per_class_lines:
            f.write(line + "\n")

        f.write("\nTraining Configuration:\n")
        for key, value in config_dict.items():
            f.write(f"  {key}: {value}\n")

    # ------------------------------------------------------------------
    # Confusion matrices (rows: true class, columns: predicted class)
    # ------------------------------------------------------------------
    print("\n=== Generating Confusion Matrices ===")

    teacher_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    student_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    # ufunc.at accumulates correctly when the same (true, pred) pair repeats.
    np.add.at(teacher_confusion, (labels, teacher_preds), 1)
    np.add.at(student_confusion, (labels, student_preds), 1)

    print("\nVanilla 2D Teacher Confusion Matrix:")
    print(teacher_confusion)
    print("\nVanilla 2D Student Confusion Matrix:")
    print(student_confusion)

    np.savetxt("results/metrics/vanilla_2d_qa_teacher_confusion_matrix.csv", teacher_confusion, delimiter=',', fmt='%d')
    np.savetxt("results/metrics/vanilla_2d_qa_student_confusion_matrix.csv", student_confusion, delimiter=',', fmt='%d')

    # Index -> class name legend for reading the confusion matrices
    with open("results/metrics/vanilla_2d_qa_confusion_matrix_class_names.txt", "w", encoding="utf-8") as f:
        f.write("Class indices to names mapping (QA 4-class):\n")
        for i in valid_classes:
            f.write(f"{i}: {class_names[i]}\n")

    # ------------------------------------------------------------------
    # Model size comparison
    # ------------------------------------------------------------------
    print("\n=== Model Size Comparison ===")

    def count_parameters(model):
        """Return the number of trainable parameters of ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    teacher_params = count_parameters(teacher_model)
    student_params = count_parameters(student_model)
    param_reduction_pct = (teacher_params - student_params) / teacher_params * 100

    print(f"Teacher model parameters: {teacher_params:,}")
    print(f"Student model parameters: {student_params:,}")
    print(f"Parameter reduction: {param_reduction_pct:.2f}%")

    # ------------------------------------------------------------------
    # Performance by question length (same buckets as the 3D version)
    # ------------------------------------------------------------------
    print("\n=== Performance Analysis by Question Length ===")

    # Buckets: short < 20 <= medium < 50 <= long (in attended tokens).
    length_masks = {
        "short": seq_lengths < 20,
        "medium": (seq_lengths >= 20) & (seq_lengths < 50),
        "long": seq_lengths >= 50,
    }
    # Plain Python ints keep this structure JSON-serializable.
    performance_by_length = {
        bucket: {
            "total": int(mask.sum()),
            "teacher": int((teacher_hits & mask).sum()),
            "student": int((student_hits & mask).sum()),
        }
        for bucket, mask in length_masks.items()
    }

    # Examples the student gets right while the teacher gets them wrong.
    improved_mask = student_hits & ~teacher_hits
    improved_examples = [
        {"true_label": int(true), "teacher_pred": int(pred), "length": int(length)}
        for true, pred, length in zip(
            labels[improved_mask], teacher_preds[improved_mask], seq_lengths[improved_mask]
        )
    ]

    length_lines = []
    for bucket, stats in performance_by_length.items():
        if stats["total"] > 0:
            bucket_teacher_acc = stats["teacher"] / stats["total"]
            bucket_student_acc = stats["student"] / stats["total"]
            length_lines.append(f"  {bucket.capitalize()} questions ({stats['total']} examples):")
            length_lines.append(
                f"    Teacher: {bucket_teacher_acc:.4f}, Student: {bucket_student_acc:.4f}, "
                f"Diff: {bucket_student_acc - bucket_teacher_acc:.4f}"
            )

    print("\nPerformance by question length:")
    for line in length_lines:
        print(line)

    print(f"\nVanilla 2D student model improved on {len(improved_examples)} examples")

    # Count improvements per question type, in order of first appearance.
    improved_by_class = {}
    for example in improved_examples:
        true_label = example["true_label"]
        class_name = label_map.get(true_label, f"Class {true_label}")
        improved_by_class[class_name] = improved_by_class.get(class_name, 0) + 1

    print("\nImprovements by question type:")
    for class_name, count in improved_by_class.items():
        print(f"  {class_name}: {count} examples")

    # Save model comparison stats (overall accuracies, not per-bucket ones)
    model_stats = {
        "teacher_params": teacher_params,
        "student_params": student_params,
        "param_reduction_pct": param_reduction_pct,
        "teacher_layers": NUM_LAYERS,
        "student_layers": STUDENT_LAYERS,
        "teacher_accuracy": overall_teacher_acc,
        "student_accuracy": overall_student_acc,
        "accuracy_difference": overall_student_acc - overall_teacher_acc,
        "attention_type": "vanilla_2d",
        "improved_examples": len(improved_examples),
        "performance_by_length": performance_by_length,
    }

    with open("results/metrics/vanilla_2d_qa_model_stats.json", "w", encoding="utf-8") as f:
        json.dump(model_stats, f, indent=2)

    # Save detailed analysis
    with open("results/metrics/vanilla_2d_qa_analysis.txt", "w", encoding="utf-8") as f:
        f.write("Analysis of Vanilla 2D Attention Model Performance\n")
        f.write("===============================================\n\n")

        f.write("Performance by question length:\n")
        for line in length_lines:
            f.write(line + "\n")

        f.write(f"\nVanilla 2D student model improved on {len(improved_examples)} examples\n")
        f.write("\nImprovements by question type:\n")
        for class_name, count in improved_by_class.items():
            f.write(f"  {class_name}: {count} examples\n")

    print("\nAll vanilla 2D results saved successfully!")


# Run the full pipeline only when this file is executed as a script, so that
# importing the module does not start training.
if __name__ == "__main__":
    main()