#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Higher-Order Attention RoBERTa with Knowledge Distillation
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm

# Clear datasets cache to avoid issues
import datasets
datasets.disable_caching()

from datasets import load_dataset, concatenate_datasets

# Fixed imports for compatible versions
from transformers import (
    RobertaTokenizer, 
    RobertaModel, 
    RobertaConfig,
    RobertaForSequenceClassification
)

from transformers.models.roberta.modeling_roberta import (
    RobertaAttention, RobertaSelfAttention, RobertaLayer,
    RobertaEncoder, RobertaClassificationHead
)

from torch.utils.data import DataLoader
from torch.optim import AdamW

# Select the compute device: CUDA when a GPU is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create the output directories (no error if they already exist)
os.makedirs("results", exist_ok=True)
os.makedirs("results/models", exist_ok=True)
os.makedirs("results/metrics", exist_ok=True)

# Seed the PyTorch and NumPy random number generators for reproducibility
torch.manual_seed(42)
np.random.seed(42)


class BlockDiagonalHigherOrderAttention(nn.Module):
    """
    Block-local higher-order (query / key-pair) attention.

    The sequence is cut into fixed-size, overlapping blocks. Inside a block
    each query attends to *pairs* of key positions: the score of the triple
    ``(i, j, k)`` is ``sum_d q_i[d] * k1_j[d] * k2_k[d]`` and the value
    attached to the pair ``(j, k)`` is the element-wise product
    ``v1_j * v2_k``. Tokens covered by several blocks receive the average of
    the per-block results.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.
        order: Stored on the instance; the computation implemented here is
            always the pairwise form described above.
        block_size: Maximum number of tokens per block.
        overlap_ratio: Fraction of ``block_size`` shared by consecutive blocks.
    """

    # Finite stand-in for "minus infinity" on masked scores. A finite value
    # keeps the softmax well defined (uniform) even when every key of a block
    # is padding, where true -inf would produce NaN.
    MASK_FILL_VALUE = -10000.0

    def __init__(self, config, order=2, block_size=32, overlap_ratio=0.25):
        super().__init__()
        self.order = order
        self.block_size = block_size
        self.overlap_size = int(block_size * overlap_ratio)
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Projections for query and keys
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key1 = nn.Linear(config.hidden_size, self.all_head_size)
        self.key2 = nn.Linear(config.hidden_size, self.all_head_size)

        # Projections for values
        self.value1 = nn.Linear(config.hidden_size, self.all_head_size)
        self.value2 = nn.Linear(config.hidden_size, self.all_head_size)

        # Output projection
        self.output_projection = nn.Linear(self.all_head_size, config.hidden_size)

        # Cross-block interaction layer.
        # NOTE: not used by any method of this class; kept so the module's
        # parameter set (state_dict keys) stays unchanged.
        self.cross_block_attention = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True
        )

        # Boundary scorer (also unused by the methods below, kept for the
        # same state_dict-compatibility reason).
        self.boundary_scorer = nn.Linear(config.hidden_size, 1)

        # Dropout applied to the higher-order attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # Layer normalization applied to the input of forward()
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        # 1 / sqrt(head_dim) scaling applied to raw attention scores
        self.scale_factor = 1.0 / math.sqrt(self.attention_head_size)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the projections with a reduced-gain Xavier scheme."""
        for module in [self.query, self.key1, self.key2, self.value1, self.value2,
                      self.output_projection]:
            # Xavier-uniform with gain 0.5 (smaller variance than the default)
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        # Boundary scorer starts with very small weights and zero bias
        nn.init.normal_(self.boundary_scorer.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.boundary_scorer.bias)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_dim)``."""
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def get_fixed_blocks(self, seq_len):
        """
        Return the ``(start, end)`` spans of the overlapping blocks.

        Consecutive blocks start ``block_size - overlap_size`` tokens apart.
        The stride is floored at 1 so an overlap as large as the block cannot
        loop forever, and generation stops as soon as a block reaches the end
        of the sequence so no trailing block is fully contained in the
        previous one.
        """
        stride = max(1, self.block_size - self.overlap_size)
        blocks = []
        current_pos = 0

        while current_pos < seq_len:
            end_pos = min(current_pos + self.block_size, seq_len)
            blocks.append((current_pos, end_pos))
            if end_pos == seq_len:
                break
            current_pos += stride

        return blocks

    def _additive_key_mask(self, attention_mask):
        """
        Normalize ``attention_mask`` to an additive ``(batch, seq)`` key mask.

        Accepted layouts are ``(batch, seq)``, ``(batch, 1, seq)`` and
        ``(batch, 1, 1, seq)`` (for masks with a query axis the first query
        row is used). Two value conventions are recognised:

        * binary keep-mask (1 = attend, 0 = ignore): any non-floating dtype,
          or a floating mask that contains a positive entry;
        * additive mask (0 = attend, negative = ignore): everything else.

        Returns ``None`` when no mask is given; otherwise a float tensor whose
        negative entries mark key positions to ignore.
        """
        if attention_mask is None:
            return None

        if attention_mask.dim() == 4:
            key_mask = attention_mask[:, 0, 0, :]
        elif attention_mask.dim() == 3:
            key_mask = attention_mask[:, 0, :]
        elif attention_mask.dim() == 2:
            key_mask = attention_mask
        else:
            raise ValueError(
                f"attention_mask must have 2, 3 or 4 dimensions, got {attention_mask.dim()}"
            )

        # An additive mask never holds positive values, a usable binary mask
        # always holds at least one 1 -- that is what tells them apart.
        is_binary = (not key_mask.is_floating_point()) or bool((key_mask > 0).any())
        key_mask = key_mask.float()
        if is_binary:
            key_mask = (1.0 - key_mask) * self.MASK_FILL_VALUE
        return key_mask

    def _standard_attention(self, q_block, k_block, v_block, key_masked=None):
        """
        Plain scaled dot-product attention used for tiny blocks and as fallback.

        ``key_masked`` is an optional boolean vector over key positions; True
        entries are excluded from the softmax.
        """
        scores = torch.matmul(q_block, k_block.transpose(-1, -2)) * self.scale_factor
        if key_masked is not None:
            # Broadcast over heads and queries so the mask hits the KEY axis.
            scores = scores.masked_fill(key_masked.view(1, 1, -1), self.MASK_FILL_VALUE)
        probs = F.softmax(scores, dim=-1)
        return torch.matmul(probs, v_block)

    def compute_block_attention(self, q_block, k1_block, k2_block, v1_block, v2_block, block_mask=None):
        """
        Compute higher-order attention within a single block.

        Args:
            q_block, k1_block, k2_block, v1_block, v2_block: Tensors of shape
                ``(heads, block_len, head_dim)`` for one batch item.
            block_mask: Optional additive mask over the block's key positions
                (``block_len`` elements; negative = ignore).

        Returns:
            Context tensor of shape ``(heads, block_len, head_dim)``. Inputs
            longer than ``block_size`` are truncated to ``block_size`` first.
        """
        block_len = q_block.size(1)

        key_masked = None
        if block_mask is not None:
            key_masked = block_mask.reshape(-1) < 0

        # Clip block length to prevent memory issues
        if block_len > self.block_size:
            q_block = q_block[:, :self.block_size, :]
            k1_block = k1_block[:, :self.block_size, :]
            k2_block = k2_block[:, :self.block_size, :]
            v1_block = v1_block[:, :self.block_size, :]
            v2_block = v2_block[:, :self.block_size, :]
            if key_masked is not None:
                key_masked = key_masked[:self.block_size]
            block_len = self.block_size

        # For very small blocks, use standard attention
        if block_len <= 2:
            return self._standard_attention(q_block, k1_block, v1_block, key_masked)

        try:
            # Cap the number of key positions entering the (keys x keys)
            # tensor product to bound memory.
            max_tensor_size = min(block_len, 16)

            if block_len > max_tensor_size:
                # Evenly spaced subset of key positions
                indices = torch.linspace(0, block_len - 1, max_tensor_size, dtype=torch.long, device=q_block.device)
                k1_sampled = k1_block[:, indices, :]
                k2_sampled = k2_block[:, indices, :]
                v1_sampled = v1_block[:, indices, :]
                v2_sampled = v2_block[:, indices, :]
                sampled_masked = key_masked[indices] if key_masked is not None else None
            else:
                k1_sampled = k1_block
                k2_sampled = k2_block
                v1_sampled = v1_block
                v2_sampled = v2_block
                sampled_masked = key_masked

            sampled_len = k1_sampled.size(1)

            # 3-way scores: (heads, block_len, sampled_len, sampled_len)
            qk1k2 = torch.einsum("HiD,HjD,HkD->Hijk", q_block, k1_sampled, k2_sampled)
            qk1k2 = qk1k2 * self.scale_factor

            if torch.isnan(qk1k2).any():
                print("Warning: NaN detected in attention computation, using fallback")
                return self._standard_attention(q_block, k1_block, v1_block, key_masked)

            # Clamp values to prevent overflow in the softmax
            qk1k2 = torch.clamp(qk1k2, min=-50, max=50)

            if sampled_masked is not None:
                # A key pair is unusable as soon as either member is masked.
                # Applied after the clamp so the fill value is not clipped.
                pair_masked = sampled_masked.unsqueeze(0) | sampled_masked.unsqueeze(1)
                qk1k2 = qk1k2.masked_fill(
                    pair_masked.view(1, 1, sampled_len, sampled_len), self.MASK_FILL_VALUE
                )

            # Softmax jointly over all (j, k) key pairs of each query
            qk1k2_flat = qk1k2.reshape(self.num_attention_heads, block_len, -1)
            attention_probs_flat = F.softmax(qk1k2_flat, dim=-1)
            attention_probs = attention_probs_flat.reshape(
                self.num_attention_heads, block_len, sampled_len, sampled_len
            )

            attention_probs = self.dropout(attention_probs)

            # Pair values: element-wise product of v1_j and v2_k
            v1v2 = torch.einsum("HiD,HjD->HijD", v1_sampled, v2_sampled)

            context = torch.einsum("Hijk,HjkD->HiD", attention_probs, v1v2)

            if torch.isnan(context).any():
                print("Warning: NaN in context, using standard attention fallback")
                context = self._standard_attention(q_block, k1_block, v1_block, key_masked)

            return context

        except RuntimeError as e:
            # Best effort: tensor-op failures (shape, out-of-memory, ...)
            # degrade to standard attention instead of aborting the step.
            print(f"Error in block attention computation: {e}, using fallback")
            return self._standard_attention(q_block, k1_block, v1_block, key_masked)

    def forward(self, hidden_states, attention_mask=None):
        """
        Apply block-wise higher-order attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional padding mask in one of the layouts and
                conventions accepted by ``_additive_key_mask``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``. If the projected
            output contains NaN, the layer-normalised input is returned
            instead.
        """
        batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)

        # Apply layer normalization for stability
        hidden_states = self.layer_norm(hidden_states)

        # Project inputs and reshape to multi-head format
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key1_layer = self.transpose_for_scores(self.key1(hidden_states))
        key2_layer = self.transpose_for_scores(self.key2(hidden_states))
        value1_layer = self.transpose_for_scores(self.value1(hidden_states))
        value2_layer = self.transpose_for_scores(self.value2(hidden_states))

        blocks = self.get_fixed_blocks(seq_length)

        key_mask = self._additive_key_mask(attention_mask)
        if key_mask is not None:
            # Lets a mask with batch dimension 1 serve every batch item.
            key_mask = key_mask.expand(batch_size, -1)

        # Number of blocks covering each token. It is identical for every
        # batch item, so it is computed once. It starts at ZERO so a token
        # covered by a single block is divided by 1 and keeps its full value.
        coverage = torch.zeros(seq_length, device=hidden_states.device, dtype=query_layer.dtype)
        for start, end in blocks:
            coverage[start:end] += 1
        # Guard against division by zero for any token no block covers.
        coverage = coverage.clamp(min=1).view(1, -1, 1)

        batch_outputs = []
        for b in range(batch_size):
            batch_context = torch.zeros_like(query_layer[b])

            for start, end in blocks:
                block_mask = key_mask[b, start:end] if key_mask is not None else None

                block_context = self.compute_block_attention(
                    query_layer[b, :, start:end, :],
                    key1_layer[b, :, start:end, :],
                    key2_layer[b, :, start:end, :],
                    value1_layer[b, :, start:end, :],
                    value2_layer[b, :, start:end, :],
                    block_mask
                )

                batch_context[:, start:end, :] += block_context

            # Average overlapping regions
            batch_outputs.append(batch_context / coverage)

        # (batch, heads, seq, head_dim) -> (batch, seq, all_head_size)
        context_layer = torch.stack(batch_outputs)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        attention_output = self.output_projection(context_layer)

        if torch.isnan(attention_output).any():
            print("Warning: NaN in final attention output, replacing with input")
            attention_output = hidden_states

        return attention_output


class BlockUnrolledHigherOrderRobertaSelfAttention(RobertaSelfAttention):
    """
    RobertaSelfAttention whose forward pass delegates to block-wise
    higher-order attention.

    Args:
        config: Model configuration passed to the parent class and to the
            higher-order attention module.
        position_embedding_type: Forwarded unchanged to the parent class.
        order: Forwarded to ``BlockDiagonalHigherOrderAttention``.
        block_size: Forwarded to ``BlockDiagonalHigherOrderAttention``.
        self_attn: Optional existing self-attention module (exposing
            ``query``, ``key`` and ``value`` linear layers) whose weights
            seed the higher-order projections.
    """

    def __init__(self, config, position_embedding_type=None, order=2, block_size=32, self_attn=None):
        super().__init__(config, position_embedding_type)

        self.higher_order_attention = BlockDiagonalHigherOrderAttention(
            config, order=order, block_size=block_size
        )

        # Seed the higher-order projections from standard attention if provided
        if self_attn is not None:
            with torch.no_grad():
                # query / key1 / value1 are exact copies of the source weights
                self.higher_order_attention.query.weight.data = self_attn.query.weight.data.clone()
                self.higher_order_attention.query.bias.data = self_attn.query.bias.data.clone()

                self.higher_order_attention.key1.weight.data = self_attn.key.weight.data.clone()
                self.higher_order_attention.key1.bias.data = self_attn.key.bias.data.clone()

                # key2 = key weights + small Gaussian noise (std 0.01) so the
                # two key projections do not start out identical
                noise = torch.randn_like(self_attn.key.weight.data) * 0.01
                self.higher_order_attention.key2.weight.data = self_attn.key.weight.data.clone() + noise
                self.higher_order_attention.key2.bias.data = self_attn.key.bias.data.clone()

                self.higher_order_attention.value1.weight.data = self_attn.value.weight.data.clone()
                self.higher_order_attention.value1.bias.data = self_attn.value.bias.data.clone()

                # value2 = value weights + small Gaussian noise (std 0.01)
                noise = torch.randn_like(self_attn.value.weight.data) * 0.01
                self.higher_order_attention.value2.weight.data = self_attn.value.weight.data.clone() + noise
                self.higher_order_attention.value2.bias.data = self_attn.value.bias.data.clone()

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        """
        Run higher-order self-attention.

        ``head_mask``, ``encoder_attention_mask`` and ``past_key_value`` are
        accepted for signature compatibility and ignored.

        Returns:
            ``(context_layer,)`` or, when ``output_attentions`` is true,
            ``(context_layer, zeros)`` where ``zeros`` is an all-zero
            placeholder of shape ``(batch, heads, seq, seq)``.

        Raises:
            NotImplementedError: If ``encoder_hidden_states`` is given.
        """
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross-attention not implemented")

        # Integer / boolean masks are binary keep-masks (1 = attend) and are
        # converted to additive form. Floating masks of ANY precision are
        # passed through untouched: comparing against float32 only would
        # invert an fp16 / bf16 additive mask a second time.
        extended_attention_mask = attention_mask
        if extended_attention_mask is not None and not extended_attention_mask.is_floating_point():
            extended_attention_mask = (1.0 - extended_attention_mask.to(torch.float32)) * -10000.0

        # Call higher-order attention with error handling
        try:
            context_layer = self.higher_order_attention(hidden_states, extended_attention_mask)
        except Exception as e:
            print(f"Error in higher-order attention: {e}, using standard attention")
            # Best-effort fallback: the input is passed through unchanged
            # (no attention is actually computed on this path).
            context_layer = hidden_states

        outputs = (context_layer,)
        if output_attentions:
            # Placeholder: no real attention map is produced by this module
            batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)
            attention_probs = torch.zeros(
                batch_size, self.num_attention_heads, seq_length, seq_length,
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            outputs = outputs + (attention_probs,)

        return outputs


class HigherOrderRobertaAttention(RobertaAttention):
    """
    RobertaAttention whose self-attention sub-module is the higher-order variant.

    When ``attn`` is supplied, its ``output`` sub-module is reused (same
    object, not a copy) and the weights of ``attn.output.dense`` are cloned
    into the higher-order output projection.
    """

    def __init__(self, config, position_embedding_type=None, order=2, block_size=32, attn=None):
        super().__init__(config, position_embedding_type)

        # Source self-attention to seed from, if a donor module was given
        donor_self_attn = attn.self if attn is not None else None
        self.self = BlockUnrolledHigherOrderRobertaSelfAttention(
            config,
            position_embedding_type=position_embedding_type,
            order=order,
            block_size=block_size,
            self_attn=donor_self_attn,
        )

        if attn is None:
            return

        self.output = attn.output

        # Clone the donor's output dense weights into the higher-order projection
        projection = self.self.higher_order_attention.output_projection
        donor_dense = attn.output.dense
        with torch.no_grad():
            projection.weight.data = donor_dense.weight.data.clone()
            projection.bias.data = donor_dense.bias.data.clone()


class HigherOrderRobertaLayer(RobertaLayer):
    """
    RobertaLayer using ``HigherOrderRobertaAttention`` as its attention block.

    When ``layer`` is supplied, the attention is seeded from
    ``layer.attention`` and the ``intermediate`` / ``output`` sub-modules of
    ``layer`` are reused as-is (same objects).
    """

    def __init__(self, config, order=2, block_size=32, layer=None):
        super().__init__(config)

        donor_attention = None if layer is None else layer.attention
        self.attention = HigherOrderRobertaAttention(
            config,
            order=order,
            block_size=block_size,
            attn=donor_attention,
        )

        if layer is not None:
            # Reuse the donor's feed-forward sub-modules
            self.intermediate = layer.intermediate
            self.output = layer.output


class HigherOrderRobertaEncoder(RobertaEncoder):
    """
    RobertaEncoder built from ``HigherOrderRobertaLayer`` modules.

    ``config.num_hidden_layers`` layers are created; when ``original_encoder``
    is given, layer ``i`` is seeded from ``original_encoder.layer[i]``.
    """

    def __init__(self, config, order=2, block_size=32, original_encoder=None):
        super().__init__(config)

        stacked_layers = []
        for index in range(config.num_hidden_layers):
            # Pick the matching donor layer, or build from scratch
            donor = None if original_encoder is None else original_encoder.layer[index]
            stacked_layers.append(
                HigherOrderRobertaLayer(
                    config,
                    order=order,
                    block_size=block_size,
                    layer=donor,
                )
            )

        self.layer = nn.ModuleList(stacked_layers)


class HigherOrderRobertaModel(RobertaModel):
    """
    RobertaModel with higher-order attention.

    Args:
        config: Model configuration.
        order: Forwarded to the higher-order encoder.
        block_size: Forwarded to the higher-order encoder.
        original_model: Optional existing model. When given, its
            ``embeddings`` and ``pooler`` are reused by reference (the same
            module objects, so their parameters are shared with
            ``original_model``) and the encoder is seeded from its encoder.
    """

    def __init__(self, config, order=2, block_size=32, original_model=None):
        super().__init__(config)
        if original_model is not None:
            self.embeddings = original_model.embeddings
            self.encoder = HigherOrderRobertaEncoder(
                config,
                order=order,
                block_size=block_size,
                original_encoder=original_model.encoder
            )
            self.pooler = original_model.pooler
            # post_init() is deliberately NOT called again here: the parent
            # constructor already ran it, and a second call would run the
            # Hugging Face weight initialiser over the freshly created
            # higher-order projections, discarding the weights that were
            # just copied from ``original_model``.
        else:
            self.encoder = HigherOrderRobertaEncoder(
                config,
                order=order,
                block_size=block_size
            )
            # Built from scratch: let the library initialise the new encoder.
            self.post_init()


class HigherOrderRobertaForSequenceClassification(RobertaForSequenceClassification):
    """
    RobertaForSequenceClassification with higher-order attention.

    Args:
        config: Model configuration.
        order: Forwarded to ``HigherOrderRobertaModel``.
        block_size: Forwarded to ``HigherOrderRobertaModel``.
        original_model: Optional existing classification model. When given,
            the backbone is seeded from ``original_model.roberta`` and the
            ``classifier`` head is reused by reference (same module object).
    """

    def __init__(self, config, order=2, block_size=32, original_model=None):
        super().__init__(config)
        if original_model is not None:
            self.roberta = HigherOrderRobertaModel(
                config,
                order=order,
                block_size=block_size,
                original_model=original_model.roberta
            )
            self.classifier = original_model.classifier
            # No post_init() here: re-running weight initialisation would
            # overwrite the projections copied from ``original_model``.
        else:
            self.roberta = HigherOrderRobertaModel(
                config,
                order=order,
                block_size=block_size
            )
            # Built from scratch: initialise the newly created backbone.
            self.post_init()


def load_pretrained_teacher_model(num_layers=6, num_classes=3):
    """
    Build a truncated RoBERTa classifier from the ``roberta-base`` checkpoint.

    The returned model reuses the pretrained embeddings, the first
    ``num_layers`` encoder layers and the pooler of ``roberta-base``, and
    gets a newly constructed classification head.

    Returns:
        Tuple ``(model, config)`` where ``config`` is the modified
        ``roberta-base`` configuration used to build ``model``.
    """
    print("Loading pretrained RoBERTa model...")

    pretrained = RobertaForSequenceClassification.from_pretrained("roberta-base")

    # Start from the stock configuration and override depth, dropout and label count
    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = num_layers
    config.hidden_dropout_prob = 0.1
    config.attention_probs_dropout_prob = 0.1
    config.num_labels = num_classes

    teacher = RobertaForSequenceClassification(config)
    backbone = teacher.roberta
    donor = pretrained.roberta

    # Transplant the pretrained sub-modules into the shallower model
    backbone.embeddings = donor.embeddings
    for layer_index in range(num_layers):
        backbone.encoder.layer[layer_index] = donor.encoder.layer[layer_index]
    backbone.pooler = donor.pooler

    # Fresh classification head sized by config.num_labels
    teacher.classifier = RobertaClassificationHead(config)

    print(f"Created teacher model with {num_layers} layers and {num_classes} output classes")
    return teacher, config


def create_higher_order_student_model(config, teacher_model, order=2, block_size=16):
    """
    Build a higher-order-attention student seeded from ``teacher_model``.

    The student's ``classifier`` attribute is set to the teacher's classifier
    module itself (the same object, not a copy).

    Returns:
        The constructed ``HigherOrderRobertaForSequenceClassification``.
    """
    print(f"Creating higher-order student model (order={order}, block_size={block_size}) from teacher...")

    student_kwargs = {
        "order": order,
        "block_size": block_size,
        "original_model": teacher_model,
    }
    student = HigherOrderRobertaForSequenceClassification(config, **student_kwargs)

    # Point the student at the teacher's classification head
    student.classifier = teacher_model.classifier

    print(f"Successfully created higher-order student model")
    return student


def load_mnli_data(batch_size=16, max_length=128, split="train"):
    """
    Load GLUE/MNLI, tokenize it and wrap it in PyTorch data loaders.

    Premise/hypothesis pairs are tokenized with the ``roberta-base``
    tokenizer, padded/truncated to ``max_length``. The ``split`` argument is
    accepted but not used: the ``train`` and ``validation_matched`` splits
    are always loaded.

    Returns:
        Tuple ``(tokenizer, train_loader, val_loader, label_map)``.
    """
    from datasets import load_dataset

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    raw_splits = load_dataset("glue", "mnli")

    def tokenize_pairs(batch):
        # Encode premise and hypothesis together as one sequence pair
        return tokenizer(
            batch["premise"],
            batch["hypothesis"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    def prepare(split_name):
        # Tokenize, rename the target column and expose torch tensors
        prepared = raw_splits[split_name].map(tokenize_pairs, batched=True)
        prepared = prepared.rename_column("label", "labels")
        prepared.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        return prepared

    tokenized_train = prepare("train")
    tokenized_val = prepare("validation_matched")

    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    print(f"MNLI Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")
    return tokenizer, train_loader, val_loader, label_map




def finetune_teacher_model(model, train_loader, val_loader, epochs=5, num_classes=4, label_map=None):
    """
    Finetune the teacher model with early stopping on validation accuracy.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        train_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        val_loader: Same batch format, used for validation.
        epochs: Maximum number of epochs.
        num_classes: Number of class slots for per-class accuracy reporting;
            classes without validation samples are not printed.
        label_map: Optional ``{class_index: name}`` mapping used in the
            report; defaults to the three MNLI labels.

    Returns:
        Tuple ``(model, best_val_acc)``. When at least one checkpoint was
        written during this call, the best one is loaded back into ``model``;
        otherwise the model is returned as trained and ``best_val_acc`` is 0.0.
    """
    print(f"Finetuning teacher model for up to {epochs} epochs...")

    checkpoint_path = "results/models/roberta_teacher_best.pt"

    # Biases and LayerNorm weights are excluded from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

    # Halve the learning rate every 2 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    loss_fn = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS call wrote a checkpoint, so a stale file left by an
    # earlier run is never loaded by mistake.
    checkpoint_saved = False

    if label_map is None:
        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight the batch-mean loss by batch size for a per-sample average
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / max(train_total, 1),
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        epoch_train_loss = train_loss / max(train_total, 1)
        epoch_train_acc = train_correct / max(train_total, 1)

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in range(num_classes):
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / max(val_total, 1)
        epoch_val_acc = val_correct / max(val_total, 1)

        print(f"Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy
        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Restore the best weights, but only if this run actually produced them
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best teacher model with val accuracy: {best_val_acc:.4f}")
    else:
        print("No teacher checkpoint was saved during this run; returning the model as trained.")

    return model, best_val_acc


def train_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=5,
        learning_rate=5e-6,  # Much smaller learning rate
        weight_decay=0.01,
        alpha=0.3,  # Less weight on distillation
        temperature=2.0,  # Lower temperature
        num_classes=4,
        label_map=None,
        patience=2
):
    """Train the student against hard labels and the teacher's soft targets.

    Each training batch is scored with
    ``(1 - alpha) * cross_entropy + alpha * KL(student || teacher) * T**2``.
    Batches that produce NaN logits, NaN losses, or a failing student forward
    pass are skipped; batches whose gradients contain NaN are counted in the
    running metrics but do not trigger an optimizer step.  After every epoch
    the student is evaluated on ``val_loader`` and the weights with the best
    validation accuracy are written to
    ``results/models/roberta_student_best.pt``.

    Args:
        student_model: Model being trained.  It is called with ``input_ids``
            and ``attention_mask`` keyword arguments and must return an object
            exposing ``.logits``.
        teacher_model: Frozen reference model with the same call convention;
            it is only ever run in eval mode under ``torch.no_grad()``.
        train_loader: Iterable of batches indexed by ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``.
        val_loader: Iterable of validation batches with the same keys.
        epochs: Maximum number of epochs.
        learning_rate: AdamW learning rate.
        weight_decay: Decay applied to every parameter whose name contains
            neither ``"bias"`` nor ``"LayerNorm.weight"``.
        alpha: Weight of the distillation term in the combined loss.
        temperature: Softmax temperature applied to both sets of logits.
        num_classes: Number of class indices tracked for per-class accuracy.
            Indices that never occur in the validation labels are not printed.
        label_map: Optional ``{class_index: name}`` mapping used when printing
            per-class accuracy.  Defaults to the three MNLI label names.
        patience: Number of consecutive epochs without a validation-accuracy
            improvement after which training stops early.

    Returns:
        ``(student_model, best_val_acc)``.  If at least one checkpoint was
        written during this call, the student carries those best weights;
        otherwise it keeps the weights from the last epoch.
    """
    print(f"Training student model with distillation for up to {epochs} epochs...")

    checkpoint_path = "results/models/roberta_student_best.pt"
    if label_map is None:
        label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # The learning rate is multiplied by 0.8 every two epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)

    # Loss functions
    ce_loss_fn = nn.CrossEntropyLoss()
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    # Training tracking
    best_val_acc = 0.0
    no_improvement = 0
    # Set only when THIS call writes a checkpoint, so that a file left behind
    # by an earlier run is never mistaken for this run's best model.
    saved_checkpoint = False

    for epoch in range(epochs):
        # Training
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Student Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            # Forward pass - student.  A failing or NaN-producing batch is
            # reported and skipped instead of aborting the whole run.
            try:
                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits

                if torch.isnan(student_logits).any():
                    print(f"NaN detected in student logits at batch {batch_idx}")
                    continue

            except Exception as e:
                print(f"Error in student forward pass at batch {batch_idx}: {e}")
                continue

            # Forward pass - teacher (no gradients needed)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            # Cross-entropy loss against the hard labels
            ce_loss = ce_loss_fn(student_logits, labels)

            if torch.isnan(ce_loss):
                print(f"NaN in CE loss at batch {batch_idx}")
                continue

            # Distillation loss; on any failure fall back to CE only.
            try:
                student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
                teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)

                # Clamp to prevent extreme values
                student_logits_soft = torch.clamp(student_logits_soft, min=-50, max=50)
                teacher_logits_soft = torch.clamp(teacher_logits_soft, min=1e-8, max=1.0)

                # T**2 rescales the soft-target term after dividing logits by T.
                kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)

                if torch.isnan(kl_loss):
                    print(f"NaN in KL loss at batch {batch_idx}, using CE loss only")
                    kl_loss = torch.tensor(0.0, device=ce_loss.device)

            except Exception as e:
                print(f"Error in KL loss computation at batch {batch_idx}: {e}")
                kl_loss = torch.tensor(0.0, device=ce_loss.device)

            # Combined loss
            loss = (1 - alpha) * ce_loss + alpha * kl_loss

            if torch.isnan(loss):
                print(f"NaN in combined loss at batch {batch_idx}")
                continue

            # Backward pass with aggressive gradient clipping
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=0.5)

            # A finite total norm means every gradient entry was finite, so
            # the per-parameter NaN scan (one device sync per tensor) is only
            # needed to name the offender when the norm itself is non-finite.
            has_nan_grad = False
            if not math.isfinite(float(grad_norm)):
                for name, param in student_model.named_parameters():
                    if param.grad is not None and torch.isnan(param.grad).any():
                        print(f"NaN gradient in {name}")
                        has_nan_grad = True
                        break

            if not has_nan_grad:
                optimizer.step()

            # Track metrics (also for batches whose step was skipped above)
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        # max(..., 1) guards against an epoch in which every batch was skipped.
        epoch_train_loss = train_loss / max(train_total, 1)
        epoch_train_acc = train_correct / max(train_total, 1)

        # Validation
        student_model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                try:
                    student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                    student_logits = student_outputs.logits

                    teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                    teacher_logits = teacher_outputs.logits

                    # Same combined objective as in training, without clamping
                    ce_loss = ce_loss_fn(student_logits, labels)

                    student_logits_soft = F.log_softmax(student_logits / temperature, dim=-1)
                    teacher_logits_soft = F.softmax(teacher_logits / temperature, dim=-1)
                    kl_loss = kl_loss_fn(student_logits_soft, teacher_logits_soft) * (temperature ** 2)

                    loss = (1 - alpha) * ce_loss + alpha * kl_loss
                    val_loss += loss.item() * labels.size(0)

                    # Accuracy
                    preds = torch.argmax(student_logits, dim=-1)
                    val_correct += (preds == labels).sum().item()
                    val_total += labels.size(0)

                    # Per-class accuracy
                    for i in range(num_classes):
                        class_mask = (labels == i)
                        class_count = class_mask.sum().item()
                        if class_count > 0:
                            class_correct[i] += ((preds == labels) & class_mask).sum().item()
                            class_total[i] += class_count

                except Exception as e:
                    print(f"Error in validation: {e}")
                    continue

        epoch_val_loss = val_loss / max(val_total, 1)
        epoch_val_acc = val_correct / max(val_total, 1)

        print(f"Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy (classes absent from validation are omitted)
        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(student_model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            print(f"  New best student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Restore the best weights, but only if this run actually produced them.
    if saved_checkpoint and os.path.exists(checkpoint_path):
        student_model.load_state_dict(torch.load(checkpoint_path))
        print(f"Loaded best student model with val accuracy: {best_val_acc:.4f}")

    return student_model, best_val_acc


def main_with_stability_fixes():
    """Run the teacher fine-tuning / student distillation pipeline end to end.

    Steps, in order: load the MNLI loaders, build and fine-tune the RoBERTa
    teacher, build the higher-order student from the teacher, distil the
    teacher into the student, and finally score both models on the validation
    loader and print their accuracies side by side.  Returns nothing.
    """
    print("\n=== Fixed Higher-Order RoBERTa with Knowledge Distillation ===\n")

    # --- Run configuration -------------------------------------------------
    # Model shape
    NUM_LAYERS = 6
    NUM_CLASSES = 3
    ATTENTION_ORDER = 2
    BLOCK_SIZE = 32  # Smaller block size
    # Data
    BATCH_SIZE = 8  # Smaller batch size
    MAX_LENGTH = 384  # Shorter sequences
    MAX_SAMPLES = 2000  # Only reported below; not forwarded to the loader
    # Schedule and distillation
    TEACHER_EPOCHS = 3
    STUDENT_EPOCHS = 3
    DISTILLATION_ALPHA = 0.5
    TEMPERATURE = 1.0

    print(f"Using conservative parameters for stability:")
    print(f"  Block Size: {BLOCK_SIZE}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Max Length: {MAX_LENGTH}")
    print(f"  Max Samples: {MAX_SAMPLES}")

    # --- Data --------------------------------------------------------------
    loaded = load_mnli_data(batch_size=BATCH_SIZE, max_length=MAX_LENGTH)
    tokenizer, train_loader, val_loader, label_map = loaded

    # --- Teacher -----------------------------------------------------------
    teacher_model, roberta_config = load_pretrained_teacher_model(
        num_layers=NUM_LAYERS, num_classes=NUM_CLASSES
    )
    teacher_model.to(device)

    print("\n=== Step 1: Finetuning RoBERTa Teacher Model ===\n")
    # The accuracy reported by fine-tuning is recomputed in the final
    # comparison, so it is not kept here.
    teacher_model, _ = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=TEACHER_EPOCHS,
        num_classes=NUM_CLASSES,
        label_map=label_map,
    )

    # --- Student -----------------------------------------------------------
    print("\n=== Step 2: Training Higher-Order RoBERTa Student Model with Distillation ===\n")
    student_model = create_higher_order_student_model(
        roberta_config, teacher_model, order=ATTENTION_ORDER, block_size=BLOCK_SIZE
    )
    student_model.to(device)

    student_model, best_val_acc = train_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=STUDENT_EPOCHS,
        learning_rate=5e-6,
        weight_decay=0.01,
        alpha=DISTILLATION_ALPHA,
        temperature=TEMPERATURE,
        num_classes=NUM_CLASSES,
    )

    print("\n=== Training Complete ===")
    print(f"Best student model validation accuracy: {best_val_acc:.4f}")
    print(f"Models saved to: results/models/")

    # --- Head-to-head evaluation -------------------------------------------
    print("\n=== Comparing Teacher vs Student Performance ===")
    teacher_model.eval()
    student_model.eval()

    def count_hits(model, ids, mask, gold):
        """Return how many argmax predictions of ``model`` equal ``gold``."""
        logits = model(input_ids=ids, attention_mask=mask).logits
        return (torch.argmax(logits, dim=-1) == gold).sum().item()

    teacher_hits = 0
    student_hits = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating Models"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["labels"].to(device)

            try:
                # Order matters: the teacher is scored first, and the sample
                # count only grows once both models handled the batch.
                teacher_hits += count_hits(teacher_model, ids, mask, gold)
                student_hits += count_hits(student_model, ids, mask, gold)
                seen += gold.size(0)
            except Exception as e:
                print(f"Error in evaluation: {e}")
                continue

    if not seen > 0:
        print("No valid evaluations completed.")
        return

    teacher_acc = teacher_hits / seen
    student_acc = student_hits / seen

    print(f"Final RoBERTa Teacher Accuracy: {teacher_acc:.4f}")
    print(f"Final Higher-Order Student Accuracy: {student_acc:.4f}")
    print(f"Difference: {(student_acc - teacher_acc):.4f}")




# Script entry point: run the whole pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_with_stability_fixes()
