

# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-


"""
Higher-Order Attention RoBERTa with Knowledge Distillation
Using tensor product interactions for higher-order attention.
Adapted for 4-class question-type classification built from SQuAD questions.
"""

import copy
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, concatenate_datasets
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from transformers.models.roberta.modeling_roberta import RobertaAttention, RobertaSelfAttention, RobertaLayer, \
    RobertaEncoder
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

# Pick the compute device: CUDA when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create the output directories used for checkpoints and metric CSVs
os.makedirs("results", exist_ok=True)
os.makedirs("results/models", exist_ok=True)
os.makedirs("results/metrics", exist_ok=True)

# Fix the PyTorch and NumPy random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


class HigherOrderAttention(nn.Module):
    """
    Multi-head attention over *pairs* of positions.

    Every query position i scores each pair (j, k) with a trilinear product of
    one query and two key projections. The scores are normalised jointly over
    all (j, k) pairs and used to aggregate the element-wise product of two
    value projections.
    """

    def __init__(self, config, order=2):
        """
        Build the projections from ``config``.

        ``config`` must expose ``hidden_size``, ``num_attention_heads`` and
        ``attention_probs_dropout_prob``. ``order`` is stored on the module
        but is not used by the computation in this class.
        """
        super().__init__()
        self.order = order
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        width_in = config.hidden_size
        width_heads = self.all_head_size

        # One query projection and two independent key projections
        self.query = nn.Linear(width_in, width_heads)
        self.key1 = nn.Linear(width_in, width_heads)
        self.key2 = nn.Linear(width_in, width_heads)

        # Two independent value projections
        self.value1 = nn.Linear(width_in, width_heads)
        self.value2 = nn.Linear(width_in, width_heads)

        # Maps the concatenated heads back to the model width
        self.output_projection = nn.Linear(width_heads, width_in)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.initialize_weights()

    def initialize_weights(self):
        """Draw all projection weights from N(0, 0.02^2) and zero the biases."""
        linear_layers = (
            self.query, self.key1, self.key2,
            self.value1, self.value2, self.output_projection,
        )
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (batch, seq, all_head) -> (batch, heads, seq, head_dim)."""
        lead_dims = x.size()[:-1]
        per_head = x.view(*lead_dims, self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def _attend_single(self, q, k1, k2, v1, v2, mask, seq_length):
        """
        Run the pairwise attention for one batch item.

        q, k1, k2, v1, v2: (num_heads, seq_len, head_dim)
        mask: additive mask of shape (seq_len,) or None
        Returns the context tensor of shape (num_heads, seq_len, head_dim).
        """
        # Trilinear scores, one per (query i, key j, key k): (heads, seq, seq, seq)
        scores = torch.einsum("HiD,HjD,HkD->Hijk", q, k1, k2)
        scores = scores / math.sqrt(self.attention_head_size)

        if mask is not None:
            # The same additive mask is broadcast along each of the three
            # positional axes and the three copies are summed.
            along_i = mask[None, :, None, None]
            along_j = mask[None, None, :, None]
            along_k = mask[None, None, None, :]
            scores = scores + (along_i + along_j + along_k)

        # Normalise jointly over every (j, k) pair of a given query position
        flat_probs = F.softmax(
            scores.reshape(self.num_attention_heads, seq_length, -1), dim=-1
        )
        probs = flat_probs.reshape(
            self.num_attention_heads, seq_length, seq_length, seq_length
        )
        probs = self.dropout(probs)

        # Element-wise product of the two value projections per pair: (heads, seq, seq, head_dim)
        pair_values = torch.einsum("HiD,HjD->HijD", v1, v2)

        # Weighted sum over all pairs: (heads, seq, head_dim)
        return torch.einsum("Hijk,HjkD->HiD", probs, pair_values)

    def forward(self, hidden_states, attention_mask=None):
        """
        Compute the higher-order attention output.

        hidden_states: (batch_size, seq_len, hidden_size)
        attention_mask: optional additive mask indexed as [b, 0, 0, :],
            i.e. (batch_size, 1, 1, seq_len)
        Returns a tensor of shape (batch_size, seq_len, hidden_size).

        Batch items are processed one at a time so that only a single
        (heads, seq, seq, seq) score tensor is built per iteration.
        """
        batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)

        # Project and reshape everything to (batch, heads, seq, head_dim)
        queries, keys_a, keys_b, values_a, values_b = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key1, self.key2, self.value1, self.value2)
        )

        per_item_context = []
        for item in range(batch_size):
            item_mask = None
            if attention_mask is not None:
                item_mask = attention_mask[item, 0, 0, :]  # (seq_len,)
            per_item_context.append(
                self._attend_single(
                    queries[item], keys_a[item], keys_b[item],
                    values_a[item], values_b[item],
                    item_mask, seq_length,
                )
            )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        stacked = torch.stack(per_item_context)
        merged = stacked.permute(0, 2, 1, 3).contiguous()
        merged = merged.view(*(merged.size()[:-2] + (self.all_head_size,)))

        return self.output_projection(merged)


class HigherOrderRobertaSelfAttention(RobertaSelfAttention):
    """
    Drop-in replacement for RobertaSelfAttention that routes the forward pass
    through a HigherOrderAttention module.

    When ``self_attn`` (an existing self-attention module exposing ``query``,
    ``key`` and ``value`` linear layers) is given, its weights seed the
    higher-order projections: ``query``/``key1``/``value1`` receive exact
    copies, while ``key2``/``value2`` receive the same weights plus small
    Gaussian noise so the two branches do not start out identical.
    """

    def __init__(self, config, position_embedding_type=None, order=2, self_attn=None):
        super().__init__(config, position_embedding_type)
        # Create higher-order attention module
        self.higher_order_attention = HigherOrderAttention(config, order=order)

        if self_attn is not None:
            hoa = self.higher_order_attention
            # Order matters for reproducibility: the key2 noise is drawn
            # before the value2 noise.
            self._load_linear(hoa.query, self_attn.query)
            self._load_linear(hoa.key1, self_attn.key)
            self._load_linear(hoa.key2, self_attn.key, weight_noise_std=0.01)
            self._load_linear(hoa.value1, self_attn.value)
            self._load_linear(hoa.value2, self_attn.value, weight_noise_std=0.01)

    @staticmethod
    def _load_linear(target, source, weight_noise_std=0.0):
        """Copy ``source``'s weight and bias into ``target``, optionally perturbing the weight."""
        weight = source.weight.data.clone()
        if weight_noise_std:
            weight = weight + weight_noise_std * torch.randn_like(source.weight.data)
        target.weight.data = weight
        target.bias.data = source.bias.data.clone()

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        """
        Forward pass using higher-order attention.

        Args:
            hidden_states: (batch_size, seq_len, hidden_size) input.
            attention_mask: optional mask. A floating-point mask is taken to be
                additive already (0 = keep, large negative = drop); an integer
                or boolean mask is taken to be 1 = keep / 0 = drop and is
                converted. A 2D (batch_size, seq_len) mask is expanded to
                (batch_size, 1, 1, seq_len).
            head_mask, encoder_attention_mask, past_key_value: accepted for
                signature compatibility and ignored.
            encoder_hidden_states: must be None.
            output_attentions: when true, a zero placeholder of shape
                (batch_size, num_heads, seq_len, seq_len) is appended to the
                output; the real attention weights are not returned.

        Returns:
            ``(context,)`` or ``(context, placeholder_attention)``.

        Raises:
            NotImplementedError: if ``encoder_hidden_states`` is provided.
        """
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross-attention not implemented for higher-order attention")

        extended_attention_mask = attention_mask
        if extended_attention_mask is not None:
            # Only 0/1-style (non-floating) masks need converting. Testing for
            # "not float32" would also re-convert additive masks that arrive
            # in half or double precision.
            if not torch.is_floating_point(extended_attention_mask):
                extended_attention_mask = (
                    1.0 - extended_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
            # HigherOrderAttention indexes the mask as [b, 0, 0, :]
            if extended_attention_mask.dim() == 2:
                extended_attention_mask = extended_attention_mask[:, None, None, :]

        # Call higher-order attention
        context_layer = self.higher_order_attention(hidden_states, extended_attention_mask)

        outputs = (context_layer,)
        if output_attentions:
            # Placeholder only; created on the same device/dtype as the
            # activations so downstream code can combine it with them.
            batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)
            attention_probs = torch.zeros(
                batch_size, self.num_attention_heads, seq_length, seq_length,
                device=hidden_states.device, dtype=hidden_states.dtype,
            )
            outputs = outputs + (attention_probs,)

        return outputs


class HigherOrderRobertaAttention(RobertaAttention):
    """
    RobertaAttention whose ``self`` sub-module is a HigherOrderRobertaSelfAttention.

    If ``attn`` is supplied, its self-attention seeds the higher-order module,
    its ``output`` sub-module is reused as-is (same object, not a copy), and
    the weights of ``attn.output.dense`` are additionally cloned into the
    higher-order module's ``output_projection``.
    """

    def __init__(self, config, position_embedding_type=None, order=2, attn=None):
        super().__init__(config, position_embedding_type)

        # With no source attention this is equivalent to omitting self_attn.
        source_self = None if attn is None else attn.self
        self.self = HigherOrderRobertaSelfAttention(
            config,
            position_embedding_type=position_embedding_type,
            order=order,
            self_attn=source_self,
        )
        if attn is None:
            return

        # Reuse the original output block by reference
        self.output = attn.output

        # NOTE(review): output_projection starts from the same weights as
        # attn.output.dense, which is also kept inside self.output —
        # presumably both are applied in sequence; confirm this is intended.
        projection = self.self.higher_order_attention.output_projection
        source_dense = attn.output.dense
        projection.weight.data = source_dense.weight.data.clone()
        projection.bias.data = source_dense.bias.data.clone()


class HigherOrderRobertaLayer(RobertaLayer):
    """
    RobertaLayer with its attention replaced by HigherOrderRobertaAttention.

    When ``layer`` is given, the new attention is seeded from
    ``layer.attention`` and ``layer.intermediate`` / ``layer.output`` are
    reused by reference (the same module objects, not copies).
    """

    def __init__(self, config, order=2, layer=None):
        super().__init__(config)

        # attn=None is the constructor's default, so one call covers both cases
        source_attention = layer.attention if layer is not None else None
        self.attention = HigherOrderRobertaAttention(
            config,
            order=order,
            attn=source_attention,
        )

        if layer is not None:
            # Feed-forward blocks are taken over unchanged
            self.intermediate = layer.intermediate
            self.output = layer.output


class HigherOrderRobertaEncoder(RobertaEncoder):
    """
    RobertaEncoder whose layers are HigherOrderRobertaLayer instances.

    With ``original_encoder`` given, layer ``i`` of the new encoder is built
    from ``original_encoder.layer[i]`` for the first
    ``config.num_hidden_layers`` layers; otherwise all layers are fresh.
    """

    def __init__(self, config, order=2, original_encoder=None):
        super().__init__(config)

        depth = config.num_hidden_layers
        if original_encoder is None:
            # No source layers: every new layer is built from scratch
            source_layers = [None] * depth
        else:
            source_layers = [original_encoder.layer[idx] for idx in range(depth)]

        self.layer = nn.ModuleList([
            HigherOrderRobertaLayer(config, order=order, layer=source)
            for source in source_layers
        ])


class HigherOrderRobertaModel(RobertaModel):
    """
    RobertaModel whose encoder uses higher-order attention.

    With ``original_model`` given, its ``embeddings`` and ``pooler`` are
    reused by reference (the pooler may be ``None``) and the encoder is
    rebuilt from ``original_model.encoder``. Without it, a fresh
    higher-order encoder is created and initialised.
    """

    def __init__(self, config, order=2, original_model=None):
        super().__init__(config)

        if original_model is None:
            self.encoder = HigherOrderRobertaEncoder(config, order=order)
            # Fresh encoder: let the library initialise the new modules.
            self.post_init()
            return

        # Copy all components except encoder from original model
        self.embeddings = original_model.embeddings
        self.encoder = HigherOrderRobertaEncoder(
            config,
            order=order,
            original_encoder=original_model.encoder
        )
        self.pooler = original_model.pooler

        # Do NOT call post_init() here. It runs the library's weight
        # initialisation over sub-modules not yet marked as initialised, and
        # the freshly created higher-order projections are not marked, so
        # the weights just transferred into them would be overwritten with
        # random values. super().__init__() above has already run
        # post_init() once for the base model.
        #
        # Mark everything as initialised so that a later post_init() on an
        # enclosing model leaves the transferred weights alone as well.
        # NOTE(review): `_is_hf_initialized` is the marker the Hugging Face
        # initialiser is understood to check — confirm against the installed
        # transformers version.
        for module in self.modules():
            module._is_hf_initialized = True


class HigherOrderRobertaForSequenceClassification(RobertaForSequenceClassification):
    """
    Sequence classifier on top of a HigherOrderRobertaModel.

    With ``original_model`` given, the backbone is built from
    ``original_model.roberta`` and ``original_model.classifier`` is reused by
    reference; otherwise backbone and classifier are fresh.
    """

    def __init__(self, config, order=2, original_model=None):
        super().__init__(config)

        transfer = original_model is not None

        # original_model=None is the backbone constructor's default
        self.roberta = HigherOrderRobertaModel(
            config,
            order=order,
            original_model=original_model.roberta if transfer else None,
        )
        if transfer:
            # Take over the existing classification head
            self.classifier = original_model.classifier

        # Initialize or update model
        self.post_init()


def load_pretrained_teacher_model(num_layers=6, num_classes=5):
    """
    Build a truncated RoBERTa teacher from the pretrained ``roberta-base`` weights.

    The returned model reuses the pretrained embeddings, the first
    ``num_layers`` encoder layers and the pooler attribute, and gets a newly
    constructed classification head with ``num_classes`` outputs.

    Returns:
        ``(model, config)`` where ``config`` is the modified RobertaConfig.

    Raises:
        AssertionError: if the classification head does not have
            ``num_classes`` outputs.
    """
    print("Loading pretrained RoBERTa model...")

    pretrained = RobertaForSequenceClassification.from_pretrained("roberta-base")

    # Same dimensions as roberta-base, but shallower and with custom dropout
    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = num_layers
    config.hidden_dropout_prob = 0.3
    config.attention_probs_dropout_prob = 0.3
    config.num_labels = num_classes

    teacher = RobertaForSequenceClassification(config)

    # Swap in the pretrained modules (by reference).
    # NOTE(review): these modules were built by from_pretrained, not from the
    # `config` above — presumably they keep their own dropout settings rather
    # than the 0.3 set here; confirm whether that is intended.
    backbone, pretrained_backbone = teacher.roberta, pretrained.roberta
    backbone.embeddings = pretrained_backbone.embeddings
    for idx in range(num_layers):
        backbone.encoder.layer[idx] = pretrained_backbone.encoder.layer[idx]
    backbone.pooler = pretrained_backbone.pooler

    # Fresh classification head sized by config.num_labels
    teacher.classifier = RobertaClassificationHead(config)

    head_width = teacher.classifier.out_proj.out_features
    print(f"Verifying model output size: {head_width} classes")
    assert teacher.classifier.out_proj.out_features == num_classes, "Model output size doesn't match num_classes"

    print(f"Created teacher model with {num_layers} layers from pretrained RoBERTa with {num_classes} output classes")
    return teacher, config


def create_higher_order_student_model(config, teacher_model, order=2):
    """
    Create a higher-order student model initialised from the teacher's weights.

    The student is built from a deep copy of ``teacher_model``. The
    higher-order model classes reuse several sub-modules of the model they
    are given by reference (embeddings, feed-forward blocks, classifier).
    Building directly from the teacher would make student and teacher share
    those parameters: optimising the student would silently change the
    teacher, and ``student.train()`` / ``teacher.eval()`` would toggle each
    other's shared sub-modules.

    Args:
        config: RobertaConfig used to construct the student.
        teacher_model: trained teacher; left untouched by this function.
        order: attention order forwarded to the higher-order modules.

    Returns:
        The student model, sharing no parameters with ``teacher_model``.
    """
    print(f"Creating higher-order student model (order={order}) from teacher...")

    # Detached snapshot of the teacher; the student takes ownership of its modules
    teacher_snapshot = copy.deepcopy(teacher_model)

    # Create student model with higher-order attention, transferring weights from the snapshot
    student_model = HigherOrderRobertaForSequenceClassification(
        config,
        order=order,
        original_model=teacher_snapshot
    )

    # Explicitly ensure the classifier has the correct number of output classes
    # (taken from the snapshot so it is not shared with the teacher)
    student_model.classifier = teacher_snapshot.classifier

    print(f"Successfully created higher-order student model")
    return student_model


def load_and_prepare_qa_classification_data(batch_size=16, max_length=128, max_samples=5000, num_classes=4):
    """
    Load SQuAD and turn its questions into a question-type classification task.

    Each question is labelled heuristically (field ``qa_class``):
    0 = Factoid, 1 = Explanation, 2 = Yes/No, 3 = Comparison. Only the
    question text is tokenized.

    Args:
        batch_size: batch size of both data loaders.
        max_length: questions are padded/truncated to this many tokens.
        max_samples: if set and smaller than the training split, the training
            set is reduced to at most ``max_samples // num_classes`` examples
            per class and the validation set to at most
            ``(max_samples // 2) // num_classes`` per class.
        num_classes: examples whose label is >= ``num_classes`` are dropped.

    Returns:
        ``(tokenizer, train_loader, val_loader, label_map)``.
    """
    print(f"Loading and preparing Question Classification dataset (limited to {max_samples} samples, {num_classes} classes)...")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # SQuAD is used only as a source of questions; the labels are derived below
    dataset = load_dataset("squad")

    factoid_openers = ("who", "whom", "whose", "what", "when", "where")
    explanation_openers = ("why", "how")
    yes_no_openers = ("is", "are", "was", "were", "do", "does",
                      "did", "can", "could", "will", "would", "has",
                      "have", "had", "should", "shall")
    comparison_cues = ("which", "better", "difference", "compare", "versus", "vs")

    def opens_with(question, openers):
        """True if the question's first word is one of ``openers``.

        A bare ``startswith`` would also accept "canada ..." for "can" or
        "islam ..." for "is", so the opener must be followed by a non-letter
        (or the end of the string). Negative contractions such as "isn't" /
        "doesn't" still count.
        """
        for opener in openers:
            if not question.startswith(opener):
                continue
            rest = question[len(opener):]
            if not rest[:1].isalpha() or rest.startswith(("n't", "n\u2019t")):
                return True
        return False

    def classify_question(example):
        """Attach the heuristic ``qa_class`` label to one example."""
        question = example["question"].lower().lstrip()

        if opens_with(question, factoid_openers):
            example["qa_class"] = 0  # Factoid
        elif opens_with(question, explanation_openers):
            example["qa_class"] = 1  # Explanation
        elif opens_with(question, yes_no_openers):
            example["qa_class"] = 2  # Yes/No
        elif any(cue in question for cue in comparison_cues):
            example["qa_class"] = 3  # Comparison (cue may appear anywhere)
        else:
            example["qa_class"] = 0  # Default to factoid if no pattern matches
        return example

    # Apply classification to the dataset
    train_data = dataset["train"].map(classify_question)
    val_data = dataset["validation"].map(classify_question)

    # Keep only the requested classes
    train_data = train_data.filter(lambda example: example["qa_class"] < num_classes)
    val_data = val_data.filter(lambda example: example["qa_class"] < num_classes)

    def balanced_subset(data, per_class):
        """Take up to ``per_class`` shuffled examples of every class and concatenate them."""
        parts = []
        for class_id in range(num_classes):
            members = data.filter(lambda example: example["qa_class"] == class_id)
            members = members.shuffle(seed=42).select(range(min(per_class, len(members))))
            parts.append(members)
        return concatenate_datasets(parts)

    # Limit dataset size for memory constraints
    if max_samples and max_samples < len(train_data):
        train_subset = balanced_subset(train_data, max_samples // num_classes)
        val_subset = balanced_subset(val_data, (max_samples // 2) // num_classes)
        print(f"Limited dataset to {len(train_subset)} train and {len(val_subset)} validation samples")
    else:
        train_subset = train_data
        val_subset = val_data

    def tokenize_function(examples):
        # For QA classification, we just use the question text
        return tokenizer(
            examples["question"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    # Tokenize data
    tokenized_train = train_subset.map(tokenize_function, batched=True)
    tokenized_val = val_subset.map(tokenize_function, batched=True)

    # Format for PyTorch - use the qa_class field
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])
    tokenized_val.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])

    # Create dataloaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    print(f"Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")

    # Map numeric labels to class names for better understanding
    label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    def report_distribution(title, tokenized):
        """Print how many examples of each class ``tokenized`` contains."""
        labels = [example["qa_class"].item() for example in tokenized]
        total = len(labels)
        print(title)
        for class_id in range(num_classes):
            count = labels.count(class_id)
            # Guard against an empty split instead of dividing by zero
            share = count / total * 100 if total else 0.0
            class_name = label_map.get(class_id, f"Class {class_id}")
            print(f"  {class_name}: {count} samples ({share:.2f}%)")

    report_distribution("Train class distribution:", tokenized_train)
    report_distribution("Validation class distribution:", tokenized_val)

    return tokenizer, train_loader, val_loader, label_map


def train_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=6,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=0.5,
        temperature=3.0,
        num_classes=5
):
    """
    Train the student with knowledge distillation and early stopping.

    The loss is ``(1 - alpha) * CE(student, labels) + alpha * T^2 *
    KL(softmax(teacher / T) || softmax(student / T))``. After every epoch the
    student is evaluated on ``val_loader``; the best checkpoint (by validation
    accuracy) is written to ``results/models/roberta_student_best.pt`` and
    training stops after 2 consecutive epochs without improvement.

    Args:
        student_model: model being trained.
        teacher_model: frozen reference model (run in eval mode, no gradients).
        train_loader, val_loader: loaders yielding dicts with ``input_ids``,
            ``attention_mask`` and ``qa_class``.
        epochs: maximum number of epochs.
        learning_rate: peak AdamW learning rate (10% linear warmup, then cosine decay).
        weight_decay: weight decay for all parameters except biases and LayerNorm weights.
        alpha: weight of the distillation term.
        temperature: softmax temperature T.
        num_classes: number of classes tracked in the per-class metrics.

    Returns:
        ``(student_model, best_val_acc)`` with the best checkpoint loaded
        into ``student_model`` when one was saved during this run.
    """
    print(f"Training student model with distillation for up to {epochs} epochs with early stopping...")

    best_model_path = "results/models/roberta_student_best.pt"

    # Batches are moved to `device` below, so both models must live there too
    # (a no-op when the caller has already placed them).
    student_model.to(device)
    teacher_model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm weights
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Learning rate scheduler with warmup
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)  # 10% warmup

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    # Loss functions
    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    def distillation_losses(student_logits, teacher_logits, labels):
        """Return (cross-entropy, distillation KL, combined loss) for one batch."""
        ce_loss = ce_loss_fn(student_logits, labels)
        student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
        # T^2 keeps the gradient scale of the soft term comparable to the hard term
        kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)
        return ce_loss, kl_loss, (1 - alpha) * ce_loss + alpha * kl_loss

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    checkpoint_saved = False
    patience = 2
    no_improvement = 0
    train_stats = []
    val_stats = []

    # Class names for readable reporting; unknown ids fall back to "Class <id>"
    label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    # Derived from num_classes so it always matches the per-class counters
    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # Training
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_ce_loss = 0
        train_kl_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)
            batch_count = labels.size(0)

            # Forward pass - student
            optimizer.zero_grad()
            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

            # Forward pass - teacher (no gradients needed)
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

            ce_loss, kl_loss, loss = distillation_losses(student_logits, teacher_logits, labels)
            # Sample-weighted sums so epoch averages are exact for a ragged last batch
            train_ce_loss += ce_loss.item() * batch_count
            train_kl_loss += kl_loss.item() * batch_count
            train_loss += loss.item() * batch_count

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track accuracy
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += batch_count

            # Update progress bar
            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        # Calculate epoch stats (max(..., 1) avoids dividing by zero on an empty loader)
        train_denominator = max(train_total, 1)
        epoch_train_loss = train_loss / train_denominator
        epoch_train_ce_loss = train_ce_loss / train_denominator
        epoch_train_kl_loss = train_kl_loss / train_denominator
        epoch_train_acc = train_correct / train_denominator

        # Validation
        student_model.eval()
        val_loss = 0
        val_ce_loss = 0
        val_kl_loss = 0
        val_correct = 0
        val_total = 0

        # Per-class counters, restarted every epoch
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)
                batch_count = labels.size(0)

                student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

                ce_loss, kl_loss, loss = distillation_losses(student_logits, teacher_logits, labels)
                val_ce_loss += ce_loss.item() * batch_count
                val_kl_loss += kl_loss.item() * batch_count
                val_loss += loss.item() * batch_count

                # Accuracy
                preds = torch.argmax(student_logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += batch_count

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        # Calculate validation stats
        val_denominator = max(val_total, 1)
        epoch_val_loss = val_loss / val_denominator
        epoch_val_ce_loss = val_ce_loss / val_denominator
        epoch_val_kl_loss = val_kl_loss / val_denominator
        epoch_val_acc = val_correct / val_denominator

        # Print stats
        print(f"Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy (classes without validation samples are omitted)
        print("  Per-class validation accuracy:")
        for i in valid_classes:
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save stats
        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "ce_loss": epoch_train_ce_loss,
            "kl_loss": epoch_train_kl_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "ce_loss": epoch_val_ce_loss,
            "kl_loss": epoch_val_kl_loss,
            "accuracy": epoch_val_acc
        }

        # Per-class accuracy under one stable key per class, so the metrics
        # table keeps the same columns whether or not a class had samples
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                val_stat_entry[f"{class_name}_acc"] = class_correct[i] / class_total[i]
            else:
                val_stat_entry[f"{class_name}_acc"] = 0.0

        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first evaluated
        # epoch always produces a checkpoint, so the load below never depends
        # on a file left over from a previous run.
        if not checkpoint_saved or epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(student_model.state_dict(), best_model_path)
            checkpoint_saved = True
            print(f"  New best student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only if this run actually wrote one)
    if checkpoint_saved:
        student_model.load_state_dict(torch.load(best_model_path, map_location=device))
        print(f"Loaded best student model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No checkpoint was saved during this run; keeping the current student weights.")

    # Save final model too
    torch.save(student_model.state_dict(), "results/models/roberta_student_final.pt")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/roberta_student_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/roberta_student_val_metrics.csv", index=False)

    return student_model, best_val_acc



def finetune_teacher_model(model, train_loader, val_loader, epochs=5, num_classes=4, label_map=None):
    """Finetune the teacher classifier with early stopping on validation accuracy.

    Optimises with AdamW (lr 2e-5, weight decay 0.03 on everything except
    biases and LayerNorm weights), a linear-warmup + cosine-decay learning
    rate schedule and gradient-norm clipping at 1.0.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        train_loader: Sized iterable of dict batches holding ``"input_ids"``,
            ``"attention_mask"`` and ``"qa_class"`` tensors.
        val_loader: Iterable of batches with the same keys, used for validation.
        epochs: Maximum number of epochs; training stops earlier after
            2 consecutive epochs without a validation-accuracy improvement.
        num_classes: Per-class accuracy is tracked for ids ``0..num_classes-1``.
        label_map: Optional ``{class_id: name}`` mapping used for printing and
            for metric column names. Defaults to the 4 QA class names.

    Returns:
        ``(model, best_val_acc)`` where ``model`` holds the weights of the best
        epoch (when at least one epoch ran) and ``best_val_acc`` is the best
        validation accuracy observed.

    Side effects:
        Writes ``results/models/roberta_teacher_best.pt`` and the two CSV files
        ``results/metrics/roberta_teacher_{train,val}_metrics.csv``.
    """
    print(f"Finetuning teacher model for up to {epochs} epochs with early stopping...")

    best_model_path = "results/models/roberta_teacher_best.pt"

    # Biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.03,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

    # Learning rate scheduler: 10% linear warmup, then cosine decay to zero.
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    loss_fn = nn.CrossEntropyLoss()

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS run wrote a checkpoint, so we never load a missing
    # or stale file after the loop.
    checkpoint_saved = False
    train_stats = []
    val_stats = []

    # Use the provided label_map or default to QA classification labels
    if label_map is None:
        label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Loss is a batch mean, so weight it by batch size before summing.
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        # max(1, ...) keeps an empty loader from raising ZeroDivisionError.
        epoch_train_loss = train_loss / max(1, train_total)
        epoch_train_acc = train_correct / max(1, train_total)

        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / max(1, val_total)
        epoch_val_acc = val_correct / max(1, val_total)

        # Print stats
        print(f"Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "accuracy": epoch_val_acc
        }

        print("  Per-class validation accuracy:")
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")
            else:
                class_acc = 0.0
            # Same column name whether or not the class had samples, so the
            # metrics CSV keeps one column per class across epochs.
            val_stat_entry[f"{class_name}_acc"] = class_acc

        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first evaluated
        # epoch is always checkpointed so a best-model file exists even when
        # validation accuracy never rises above 0.0.
        if not checkpoint_saved or epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
            print(f"  New best teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only if this run actually produced one)
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path))
        print(f"Loaded best teacher model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No teacher checkpoint was saved in this run; keeping the current weights.")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/roberta_teacher_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/roberta_teacher_val_metrics.csv", index=False)

    return model, best_val_acc


def train_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=6,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=0.5,
        temperature=3.0,
        num_classes=5,
        label_map=None
):
    """Train the student with knowledge distillation and early stopping.

    The loss per batch is ``(1 - alpha) * CE(student, labels) +
    alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))`` with
    ``T = temperature``. Optimisation uses AdamW, a linear-warmup +
    cosine-decay learning-rate schedule and gradient-norm clipping at 1.0.

    Args:
        student_model: Module being trained; called as
            ``student_model(input_ids=..., attention_mask=...)`` and expected
            to return an object with ``.logits``.
        teacher_model: Module providing soft targets; same call convention.
            Its parameters are not optimised here.
        train_loader: Sized iterable of dict batches holding ``"input_ids"``,
            ``"attention_mask"`` and ``"qa_class"`` tensors.
        val_loader: Iterable of batches with the same keys.
        epochs: Maximum number of epochs; training stops earlier after
            2 consecutive epochs without a validation-accuracy improvement.
        learning_rate: AdamW learning rate.
        weight_decay: Weight decay for all parameters except biases and
            LayerNorm weights.
        alpha: Weight of the distillation (KL) term; ``1 - alpha`` weights
            the cross-entropy term.
        temperature: Softmax temperature applied to both logit sets.
        num_classes: Per-class accuracy is tracked for ids ``0..num_classes-1``.
        label_map: Optional ``{class_id: name}`` mapping used for printing and
            metric column names. Defaults to the 4 QA class names; ids missing
            from the mapping are shown as ``"Class <id>"``.

    Returns:
        ``(student_model, best_val_acc)`` where the student holds the weights
        of the best epoch (when at least one epoch ran).

    Side effects:
        Writes ``results/models/roberta_student_best.pt``,
        ``results/models/roberta_student_final.pt`` and the two CSV files
        ``results/metrics/roberta_student_{train,val}_metrics.csv``.
    """
    print(f"Training student model with distillation for up to {epochs} epochs with early stopping...")

    best_model_path = "results/models/roberta_student_best.pt"

    # Biases and LayerNorm weights are exempt from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Learning rate scheduler: 10% linear warmup, then cosine decay to zero.
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    # Loss functions. KLDivLoss expects log-probabilities as input and
    # probabilities as target; "batchmean" divides the summed KL by batch size.
    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    # Training tracking
    best_val_acc = 0.0
    best_epoch = 0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS run wrote a checkpoint, so we never load a missing
    # or stale file after the loop.
    checkpoint_saved = False
    train_stats = []
    val_stats = []

    # Class names for readable output; callers may override the default.
    if label_map is None:
        label_map = {0: "Factoid", 1: "Explanation", 2: "Yes/No", 3: "Comparison"}

    valid_classes = list(range(num_classes))

    for epoch in range(epochs):
        # ---- Training ----
        # Order matters: eval() on the teacher is re-applied after train() on
        # the student every epoch (NOTE(review): the student is built from the
        # teacher elsewhere and may share modules - confirm before reordering).
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_ce_loss = 0
        train_kl_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            # Forward pass - student
            optimizer.zero_grad()
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Forward pass - teacher (no gradients needed for soft targets)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            # Cross-entropy loss against the hard labels
            ce_loss = ce_loss_fn(student_logits, labels)
            train_ce_loss += ce_loss.item() * labels.size(0)

            # Distillation loss; T^2 rescales it after the temperature softening.
            student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
            teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
            kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)
            train_kl_loss += kl_loss.item() * labels.size(0)

            # Combined loss
            loss = (1 - alpha) * ce_loss + alpha * kl_loss
            train_loss += loss.item() * labels.size(0)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Track accuracy
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": scheduler.get_last_lr()[0]
            })

        # max(1, ...) keeps an empty loader from raising ZeroDivisionError.
        train_denominator = max(1, train_total)
        epoch_train_loss = train_loss / train_denominator
        epoch_train_ce_loss = train_ce_loss / train_denominator
        epoch_train_kl_loss = train_kl_loss / train_denominator
        epoch_train_acc = train_correct / train_denominator

        # ---- Validation ----
        student_model.eval()
        val_loss = 0
        val_ce_loss = 0
        val_kl_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                # Student predictions
                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits

                # Teacher predictions
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

                # Cross-entropy loss
                ce_loss = ce_loss_fn(student_logits, labels)
                val_ce_loss += ce_loss.item() * labels.size(0)

                # Distillation loss
                student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
                teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
                kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)
                val_kl_loss += kl_loss.item() * labels.size(0)

                # Combined loss
                loss = (1 - alpha) * ce_loss + alpha * kl_loss
                val_loss += loss.item() * labels.size(0)

                # Accuracy
                preds = torch.argmax(student_logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in valid_classes:
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        val_denominator = max(1, val_total)
        epoch_val_loss = val_loss / val_denominator
        epoch_val_ce_loss = val_ce_loss / val_denominator
        epoch_val_kl_loss = val_kl_loss / val_denominator
        epoch_val_acc = val_correct / val_denominator

        # Print stats
        print(f"Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        train_stats.append({
            "epoch": epoch + 1,
            "loss": epoch_train_loss,
            "ce_loss": epoch_train_ce_loss,
            "kl_loss": epoch_train_kl_loss,
            "accuracy": epoch_train_acc
        })

        val_stat_entry = {
            "epoch": epoch + 1,
            "loss": epoch_val_loss,
            "ce_loss": epoch_val_ce_loss,
            "kl_loss": epoch_val_kl_loss,
            "accuracy": epoch_val_acc
        }

        print("  Per-class validation accuracy:")
        for i in valid_classes:
            class_name = label_map.get(i, f"Class {i}")
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")
            else:
                class_acc = 0.0
            # Same column name whether or not the class had samples, so the
            # metrics CSV keeps one column per class across epochs.
            val_stat_entry[f"{class_name}_acc"] = class_acc

        val_stats.append(val_stat_entry)

        # Save best model and check for early stopping. The first evaluated
        # epoch is always checkpointed so a best-model file exists even when
        # validation accuracy never rises above 0.0.
        if not checkpoint_saved or epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_epoch = epoch
            no_improvement = 0
            torch.save(student_model.state_dict(), best_model_path)
            checkpoint_saved = True
            print(f"  New best student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model (only if this run actually produced one)
    if checkpoint_saved:
        student_model.load_state_dict(torch.load(best_model_path))
        print(f"Loaded best student model from epoch {best_epoch + 1} with val accuracy: {best_val_acc:.4f}")
    else:
        print("No student checkpoint was saved in this run; keeping the current weights.")

    # Save final model too
    torch.save(student_model.state_dict(), "results/models/roberta_student_final.pt")

    # Save training stats
    pd.DataFrame(train_stats).to_csv("results/metrics/roberta_student_train_metrics.csv", index=False)
    pd.DataFrame(val_stats).to_csv("results/metrics/roberta_student_val_metrics.csv", index=False)

    return student_model, best_val_acc


def main():
    """Run the full pipeline: finetune the teacher, distil the student, compare both.

    Steps:
        1. Write the run configuration to
           ``results/metrics/roberta_tensor_product_config.json``.
        2. Build the data loaders and the teacher model, then finetune the teacher.
        3. Create the higher-order student from the teacher and train it with
           distillation.
        4. Run both models once over the validation loader and derive overall
           accuracy, per-class accuracy, confusion matrices and accuracy by
           question length from the collected predictions.

    All reports are printed and written under ``results/metrics/``.
    """
    import json  # function-local, as in the original block

    print("\n=== Higher-Order RoBERTa with Knowledge Distillation for QA Classification ===\n")

    # Parameters - adjusted for memory constraints
    NUM_LAYERS = 3
    BATCH_SIZE = 8
    MAX_LENGTH = 128
    MAX_SAMPLES = 3000
    TEACHER_EPOCHS = 5
    STUDENT_EPOCHS = 5
    DISTILLATION_ALPHA = 0.5
    TEMPERATURE = 3.0
    ATTENTION_ORDER = 2
    NUM_CLASSES = 4  # QA classification with 4 classes

    config_dict = {
        "model_type": "roberta",
        "model_source": "roberta-base (pretrained)",
        "num_layers": NUM_LAYERS,
        "batch_size": BATCH_SIZE,
        "max_length": MAX_LENGTH,
        "max_samples": MAX_SAMPLES,
        "teacher_epochs": TEACHER_EPOCHS,
        "student_epochs": STUDENT_EPOCHS,
        "distillation_alpha": DISTILLATION_ALPHA,
        "temperature": TEMPERATURE,
        "attention_order": ATTENTION_ORDER,
        "attention_type": "tensor_product",
        "dropout": 0.3,
        "weight_decay": 0.03,
        "learning_rate": 2e-5,
        "dataset": "QA Classification (4-class: Factoid, Explanation, Yes/No, Comparison)",
        "num_classes": NUM_CLASSES
    }

    # Save config
    with open("results/metrics/roberta_tensor_product_config.json", "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2)

    # 1. Load QA Classification data
    _tokenizer, train_loader, val_loader, label_map = load_and_prepare_qa_classification_data(
        batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        max_samples=MAX_SAMPLES,
        num_classes=NUM_CLASSES
    )

    def class_label(class_id):
        # Fall back to a generic name instead of raising KeyError on unmapped ids.
        return label_map.get(class_id, f"Class {class_id}")

    # 2. Load pretrained RoBERTa model and adapt it
    teacher_model, roberta_config = load_pretrained_teacher_model(
        num_layers=NUM_LAYERS,
        num_classes=NUM_CLASSES
    )
    teacher_model.to(device)

    # 3. Finetune teacher model
    print("\n=== Step 1: Finetuning RoBERTa Teacher Model ===\n")
    teacher_model, _ = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=TEACHER_EPOCHS,
        num_classes=NUM_CLASSES,
        label_map=label_map
    )

    # 4. Create and train student model with higher-order attention
    print("\n=== Step 2: Training Higher-Order RoBERTa Student Model with Distillation ===\n")

    student_model = create_higher_order_student_model(
        roberta_config,
        teacher_model,
        order=ATTENTION_ORDER
    )
    student_model.to(device)

    student_model, best_val_acc = train_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=STUDENT_EPOCHS,
        learning_rate=2e-5,
        weight_decay=0.03,
        alpha=DISTILLATION_ALPHA,
        temperature=TEMPERATURE,
        num_classes=NUM_CLASSES,
    )

    print("\n=== Training Complete ===")
    print(f"Best student model validation accuracy: {best_val_acc:.4f}")
    print(f"Models saved to: results/models/")
    print(f"Training metrics saved to: results/metrics/")

    # ---- Compare teacher vs student on the validation set ----
    print("\n=== Comparing Teacher vs Student Performance ===")
    teacher_model.eval()
    student_model.eval()

    # Derived from NUM_CLASSES so the reports follow the configured class count.
    valid_classes = list(range(NUM_CLASSES))

    # One inference pass collects everything the reports below need; both
    # models are in eval mode, so repeated passes would produce the same
    # predictions at three times the cost.
    label_batches = []
    teacher_pred_batches = []
    student_pred_batches = []
    length_batches = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating Models"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)

            label_batches.append(labels.cpu().numpy())
            teacher_pred_batches.append(torch.argmax(teacher_outputs.logits, dim=-1).cpu().numpy())
            student_pred_batches.append(torch.argmax(student_outputs.logits, dim=-1).cpu().numpy())
            # Sequence length per example = number of attended (unmasked) positions.
            length_batches.append(attention_mask.sum(dim=1).cpu().numpy())

    if not label_batches:
        print("Validation loader is empty; skipping teacher/student comparison.")
        return

    all_labels = np.concatenate(label_batches)
    all_teacher_preds = np.concatenate(teacher_pred_batches)
    all_student_preds = np.concatenate(student_pred_batches)
    all_lengths = np.concatenate(length_batches)

    teacher_hits = all_teacher_preds == all_labels
    student_hits = all_student_preds == all_labels
    total = len(all_labels)

    teacher_acc = int(teacher_hits.sum()) / total
    student_acc = int(student_hits.sum()) / total

    print(f"Final RoBERTa Teacher Accuracy: {teacher_acc:.4f}")
    print(f"Final Higher-Order Tensor Product RoBERTa Student Accuracy: {student_acc:.4f}")
    print(f"Difference: {(student_acc - teacher_acc):.4f}")

    # Per-class accuracies (classes without validation samples are omitted).
    per_class_lines = []
    for i in valid_classes:
        in_class = all_labels == i
        support = int(in_class.sum())
        if support == 0:
            continue
        teacher_class_acc = int((teacher_hits & in_class).sum()) / support
        student_class_acc = int((student_hits & in_class).sum()) / support
        per_class_lines.append(
            f"  {class_label(i)}: Teacher: {teacher_class_acc:.4f}, Student: {student_class_acc:.4f}, "
            f"Diff: {(student_class_acc - teacher_class_acc):.4f}"
        )

    print("\nPer-Class Accuracies:")
    for line in per_class_lines:
        print(line)

    # Save comparison
    with open("results/metrics/roberta_tensor_product_comparison.txt", "w", encoding="utf-8") as f:
        f.write("Pretrained RoBERTa Teacher vs Higher-Order Tensor Product Student Model Comparison\n")
        f.write("==========================================================================\n\n")
        f.write("Dataset: QA Classification (4-class: Factoid, Explanation, Yes/No, Comparison)\n")
        f.write(f"Teacher Model Accuracy: {teacher_acc:.4f}\n")
        f.write(f"Student Model Accuracy: {student_acc:.4f}\n")
        f.write(f"Difference: {(student_acc - teacher_acc):.4f}\n\n")

        f.write("Per-Class Accuracies:\n")
        for line in per_class_lines:
            f.write(line + "\n")

        f.write("\nTraining Configuration:\n")
        for key, value in config_dict.items():
            f.write(f"  {key}: {value}\n")

    # ---- Confusion matrices (rows: true class, columns: predicted class) ----
    print("\n=== Generating Confusion Matrices ===")

    teacher_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    student_confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    # np.add.at accumulates correctly when the same (true, pred) pair repeats.
    np.add.at(teacher_confusion, (all_labels, all_teacher_preds), 1)
    np.add.at(student_confusion, (all_labels, all_student_preds), 1)

    print("\nTeacher Confusion Matrix:")
    print(teacher_confusion)
    print("\nStudent Confusion Matrix:")
    print(student_confusion)

    np.savetxt("results/metrics/teacher_confusion_matrix.csv", teacher_confusion, delimiter=',', fmt='%d')
    np.savetxt("results/metrics/student_confusion_matrix.csv", student_confusion, delimiter=',', fmt='%d')

    # Also save the class index -> name mapping next to the matrices
    with open("results/metrics/confusion_matrix_class_names.txt", "w", encoding="utf-8") as f:
        f.write("Class indices to names mapping:\n")
        for i in valid_classes:
            f.write(f"{i}: {class_label(i)}\n")

    # ---- Detailed analysis by question length ----
    print("\n=== Detailed Analysis of Higher-Order Attention ===")

    improved_examples = []
    performance_by_length = {
        "short": {"total": 0, "teacher": 0, "student": 0},
        "medium": {"total": 0, "teacher": 0, "student": 0},
        "long": {"total": 0, "teacher": 0, "student": 0},
    }

    for seq_len, true_label, teacher_pred, teacher_ok, student_ok in zip(
            all_lengths, all_labels, all_teacher_preds, teacher_hits, student_hits):
        # Categorize by length
        length_category = "short" if seq_len < 20 else "medium" if seq_len < 50 else "long"
        bucket = performance_by_length[length_category]
        bucket["total"] += 1
        if teacher_ok:
            bucket["teacher"] += 1
        if student_ok:
            bucket["student"] += 1

        # Examples the student gets right and the teacher gets wrong
        if student_ok and not teacher_ok:
            improved_examples.append({
                "true_label": int(true_label),
                "teacher_pred": int(teacher_pred),
                "length": int(seq_len)
            })

    print("\nPerformance by question length:")
    for length, stats in performance_by_length.items():
        if stats["total"] > 0:
            bucket_teacher_acc = stats["teacher"] / stats["total"]
            bucket_student_acc = stats["student"] / stats["total"]
            print(f"  {length.capitalize()} questions ({stats['total']} examples):")
            print(f"    Teacher: {bucket_teacher_acc:.4f}, Student: {bucket_student_acc:.4f}, "
                  f"Diff: {bucket_student_acc - bucket_teacher_acc:.4f}")

    print(f"\nHigher-order model improved on {len(improved_examples)} examples")

    # Analyze improvement by class
    improved_by_class = {}
    for example in improved_examples:
        class_name = class_label(example["true_label"])
        improved_by_class[class_name] = improved_by_class.get(class_name, 0) + 1

    print("\nImprovements by question type:")
    for class_name, count in improved_by_class.items():
        print(f"  {class_name}: {count} examples")

    # # Save detailed analysis
    # with open("results/metrics/higher_order_analysis.txt", "w") as f:
    #     f.write("Analysis of Higher-Order Attention Model Advantages\n")
    #     f.write("=================================================\n\n")

    #     f.write("Performance by question length:\n")
    #     for length, stats in performance_by_length.items():
    #         if stats["total"] > 0:
    #             teacher_acc = stats["teacher"] / stats["total"]
    #             student_acc = stats["student"] / stats["total"]
    #             f.write(f"  {length.capitalize()} questions ({stats['total']} examples):\n")
    #             f.write(f"    Teacher: {teacher_acc:.4f}, Student: {student_acc:.4f}, Diff: {student_acc - teacher_acc:.4f}\n")

    #     f.write(f"\nHigher-order model improved on {len(improved_examples)} examples\n")
    #     f.write("\nImprovements by question type:\n")
    #     for class_name, count in improved_by_class.items():
    #         f.write(f"  {class_name}: {count} examples\n")

    #     f.write("\nConclusion:\n")
    #     f.write("  The higher-order attention mechanism excels particularly with complex questions that require\n")
    #     f.write("  understanding relationships between multiple entities in the text, which is a key advantage\n")
    #     f.write("  of the tensor product interaction approach used in this model.\n")

    #     # Additional analysis of where the higher-order model shines
    #     f.write("\nHigher-Order Attention Strengths:\n")
    #     f.write("  1. Complex Relational Understanding: The tensor product interactions allow the model to\n")
    #     f.write("     capture higher-order relationships between different parts of the question text.\n")
    #     f.write("  2. Contextual Nuance: The model appears to better handle nuanced questions where\n")
    #     f.write("     multiple context clues need to be integrated.\n")
    #     f.write("  3. Classification Boundaries: Higher-order attention helps distinguish between closely\n")
    #     f.write("     related question types where standard attention might struggle.\n")

    #     # Areas for potential improvement
    #     f.write("\nPotential Improvements:\n")
    #     f.write("  1. Memory Efficiency: Current implementation requires batched processing due to high\n")
    #     f.write("     memory demands of higher-order attention operations.\n")
    #     f.write("  2. Layer Selection: Applying higher-order attention selectively to certain layers could\n")
    #     f.write("     offer better performance/efficiency trade-offs.\n")
    #     f.write("  3. Hyperparameter Tuning: The higher-order model might benefit from different learning\n")
    #     f.write("     rates or temperature settings in the distillation process.\n")

    # print("\nAll results saved successfully!")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()