

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm

# Disable the datasets cache to avoid caching issues
import datasets
datasets.disable_caching()

from datasets import load_dataset, concatenate_datasets

# Transformers imports (chosen to work with compatible library versions)
from transformers import (
    RobertaTokenizer,
    RobertaModel,
    RobertaConfig,
    RobertaForSequenceClassification
)

from transformers.models.roberta.modeling_roberta import (
    RobertaClassificationHead
)

from torch.utils.data import DataLoader
from torch.optim import AdamW

# Select the compute device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set up output directories (created as a module-level side effect if missing)
os.makedirs("results", exist_ok=True)
os.makedirs("results/models", exist_ok=True)
os.makedirs("results/metrics", exist_ok=True)

# Set random seed for reproducibility (PyTorch and NumPy generators)
torch.manual_seed(42)
np.random.seed(42)


class VanillaStudentRoBERTa(RobertaForSequenceClassification):
    """
    Vanilla student RoBERTa model - just a smaller version of the teacher.

    No custom attention mechanisms are added; the class only contributes
    ``create_student_config``, which derives a scaled-down configuration
    from a teacher configuration.
    """

    def __init__(self, config):
        super().__init__(config)
        # Standard RoBERTa initialization.
        # NOTE(review): the parent constructor presumably already runs
        # post_init(); the call is kept so weight initialisation (and the RNG
        # stream it consumes) stays exactly as before - confirm before removing.
        self.post_init()

    @classmethod
    def create_student_config(cls, teacher_config, hidden_size_ratio=0.5, num_layers_ratio=0.5):
        """Create a smaller student configuration based on the teacher's.

        The teacher configuration is copied, then the hidden size, the
        intermediate size and the attention-head count are scaled by
        ``hidden_size_ratio`` and the layer count by ``num_layers_ratio``.
        Scaled values are truncated towards zero and clamped to at least 1,
        so a small teacher combined with a small ratio can never yield a
        zero-width or zero-layer student.

        Args:
            teacher_config: configuration to copy; must expose ``to_dict()``
                plus ``hidden_size``, ``num_hidden_layers``,
                ``intermediate_size`` and ``num_attention_heads``.
            hidden_size_ratio: positive scale factor for the width-related sizes.
            num_layers_ratio: positive scale factor for the number of layers.

        Returns:
            A new ``RobertaConfig``; ``teacher_config`` is not modified.

        Raises:
            ValueError: if either ratio is not strictly positive.
        """
        if hidden_size_ratio <= 0 or num_layers_ratio <= 0:
            raise ValueError(
                "hidden_size_ratio and num_layers_ratio must be positive, got "
                f"{hidden_size_ratio} and {num_layers_ratio}"
            )

        # Round-trip through a dict so the teacher's config object stays untouched
        student_config = RobertaConfig.from_dict(teacher_config.to_dict())

        # Make student smaller; int() truncates, so clamp every size to >= 1
        student_config.hidden_size = max(1, int(teacher_config.hidden_size * hidden_size_ratio))
        student_config.num_hidden_layers = max(1, int(teacher_config.num_hidden_layers * num_layers_ratio))
        student_config.intermediate_size = max(1, int(teacher_config.intermediate_size * hidden_size_ratio))
        student_config.num_attention_heads = max(1, int(teacher_config.num_attention_heads * hidden_size_ratio))

        # Ensure hidden_size is divisible by num_attention_heads; the loop
        # always terminates because every positive size is divisible by 1
        while student_config.hidden_size % student_config.num_attention_heads != 0:
            student_config.num_attention_heads -= 1

        return student_config


def load_pretrained_teacher_model(num_layers=6, num_classes=2):
    """Load pretrained ``roberta-base`` and truncate it into a teacher classifier.

    A new ``RobertaForSequenceClassification`` is built from a
    ``roberta-base`` configuration with ``num_layers`` encoder layers and
    ``num_classes`` labels. The pretrained embeddings and the first
    ``num_layers`` pretrained encoder layers are then plugged into it, and a
    newly constructed classification head is attached.

    Args:
        num_layers: number of leading pretrained encoder layers to keep.
        num_classes: number of output classes for the classification head.

    Returns:
        A ``(model, config)`` tuple: the truncated model and the config it was
        built from.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1 or larger than the
            number of encoder layers in the pretrained model.
    """
    # Reject obviously invalid depths before paying for the model download/load
    if num_layers < 1:
        raise ValueError(f"num_layers must be at least 1, got {num_layers}")

    print("Loading pretrained RoBERTa model...")

    model = RobertaForSequenceClassification.from_pretrained("roberta-base")

    pretrained_layers = model.roberta.encoder.layer
    if num_layers > len(pretrained_layers):
        # Without this check the layer-copy loop below fails with a bare IndexError
        raise ValueError(
            f"num_layers={num_layers} exceeds the {len(pretrained_layers)} "
            "encoder layers available in the pretrained model"
        )

    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = num_layers
    # Dropout rates are pinned explicitly so the teacher does not depend on
    # whatever the downloaded config happens to contain
    config.hidden_dropout_prob = 0.1
    config.attention_probs_dropout_prob = 0.1
    config.num_labels = num_classes

    # Build the truncated architecture, then swap in the pretrained submodules
    new_model = RobertaForSequenceClassification(config)
    new_model.roberta.embeddings = model.roberta.embeddings

    for i in range(num_layers):
        new_model.roberta.encoder.layer[i] = pretrained_layers[i]

    # NOTE(review): the sequence-classification model may be built without a
    # pooler, in which case this simply copies None - confirm.
    new_model.roberta.pooler = model.roberta.pooler
    # New (untrained) classification head sized by config.num_labels
    new_model.classifier = RobertaClassificationHead(config)

    print(f"Created teacher model with {num_layers} layers and {num_classes} output classes")
    return new_model, config


def create_vanilla_student_model(teacher_config, hidden_size_ratio=0.5, num_layers_ratio=0.5):
    """Build the scaled-down student model from the teacher's configuration.

    The student configuration is derived via
    ``VanillaStudentRoBERTa.create_student_config`` and the model is
    constructed directly from it (no pretrained weights are loaded here).
    A short size comparison against the teacher is printed.

    Returns:
        A ``(student_model, student_config)`` tuple.
    """
    print(f"Creating vanilla student model (hidden_ratio={hidden_size_ratio}, layer_ratio={num_layers_ratio})...")

    # Derive the smaller configuration, then instantiate the model from it
    cfg = VanillaStudentRoBERTa.create_student_config(teacher_config, hidden_size_ratio, num_layers_ratio)
    model = VanillaStudentRoBERTa(cfg)

    # Report student-vs-teacher sizes
    print(f"Student model specs:")
    print(f"  Hidden size: {cfg.hidden_size} (vs teacher: {teacher_config.hidden_size})")
    print(f"  Layers: {cfg.num_hidden_layers} (vs teacher: {teacher_config.num_hidden_layers})")
    print(f"  Attention heads: {cfg.num_attention_heads} (vs teacher: {teacher_config.num_attention_heads})")

    return model, cfg


def load_sst2_data(batch_size=16, max_length=128):
    """Load GLUE SST-2 and wrap its train/validation splits in DataLoaders.

    Sentences are tokenized with the ``roberta-base`` tokenizer, padded and
    truncated to ``max_length``. The ``label`` column is renamed to
    ``qa_class`` and each split exposes ``input_ids``, ``attention_mask`` and
    ``qa_class`` as torch tensors.

    Args:
        batch_size: batch size used by both loaders.
        max_length: fixed token length for every sentence.

    Returns:
        ``(tokenizer, train_loader, val_loader, label_map)``; only the training
        loader shuffles.
    """
    from datasets import load_dataset

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    raw_splits = load_dataset("glue", "sst2")

    def tokenize_batch(examples):
        # Fixed-length encoding so every batch has identical tensor shapes
        return tokenizer(
            examples["sentence"],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )

    def prepare_split(split_name):
        # tokenize -> rename the label column -> expose the tensor columns
        split = raw_splits[split_name].map(tokenize_batch, batched=True)
        split = split.rename_column("label", "qa_class")
        split.set_format(type="torch", columns=["input_ids", "attention_mask", "qa_class"])
        return split

    train_split = prepare_split("train")
    val_split = prepare_split("validation")

    train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_split, batch_size=batch_size)

    label_map = {0: "Negative", 1: "Positive"}

    print(f"SST-2 Train samples: {len(train_split)}, Validation samples: {len(val_split)}")

    return tokenizer, train_loader, val_loader, label_map


def finetune_teacher_model(model, train_loader, val_loader, epochs=5, num_classes=2, label_map=None):
    """Finetune the teacher classifier and restore its best-validation weights.

    Training uses AdamW (lr 1e-5; weight decay 0.01 except for parameters
    whose name contains ``bias`` or ``LayerNorm.weight``), a StepLR schedule
    stepped once per epoch, gradient-norm clipping at 1.0, and early stopping
    after two consecutive epochs without a validation-accuracy improvement.
    Each time validation accuracy improves, the state dict is written to
    ``results/models/roberta_teacher_vanilla_best.pt``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        train_loader: iterable of batches with keys ``input_ids``,
            ``attention_mask`` and ``qa_class``.
        val_loader: validation batches with the same keys.
        epochs: maximum number of epochs.
        num_classes: number of classes for the per-class accuracy report.
        label_map: optional ``{class index: display name}``; defaults to the
            SST-2 Negative/Positive mapping.

    Returns:
        ``(model, best_val_acc)``. The model carries the best checkpoint
        written during this call; if no checkpoint was written (no epoch ran,
        or validation accuracy never rose above 0.0) it is returned as last
        trained.
    """
    print(f"Finetuning teacher model for up to {epochs} epochs...")

    checkpoint_path = "results/models/roberta_teacher_vanilla_best.pt"

    # Exclude biases and LayerNorm weights from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

    # Simple scheduler: halve the learning rate every 2 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    loss_fn = nn.CrossEntropyLoss()
    best_val_acc = 0.0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS call wrote a checkpoint, so a missing file or a stale
    # file left by an earlier run is never loaded after the loop
    checkpoint_saved = False

    if label_map is None:
        label_map = {0: "Negative", 1: "Positive"}

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Teacher Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = loss_fn(logits, labels)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch average is per-sample
            train_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "loss": loss.item(),
                "acc": train_correct / train_total,
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        epoch_train_loss = train_loss / train_total
        epoch_train_acc = train_correct / train_total

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                loss = loss_fn(logits, labels)
                val_loss += loss.item() * labels.size(0)

                preds = torch.argmax(logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in range(num_classes):
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / val_total
        epoch_val_acc = val_correct / val_total

        print(f"Teacher Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"  Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy
        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best teacher model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model - only a checkpoint produced by this run
    if checkpoint_saved:
        # map_location keeps the tensors on the device the model is being run on
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best teacher model with val accuracy: {best_val_acc:.4f}")
    else:
        print("No teacher checkpoint was saved during this run; returning the model as last trained.")

    return model, best_val_acc


def train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=5,
        learning_rate=1e-4,
        weight_decay=0.01,
        alpha=0.5,  # Balance between hard and soft targets
        temperature=4.0,
        num_classes=2,
        label_map=None
):
    """Train the student with knowledge distillation from a fixed teacher.

    The per-batch objective is
    ``(1 - alpha) * CE(student, labels) + alpha * T^2 * KL(teacher_T || student_T)``
    where the ``_T`` distributions are softmaxes of the logits divided by
    ``temperature``. Training uses AdamW (no weight decay on parameters whose
    name contains ``bias`` or ``LayerNorm.weight``), a StepLR schedule stepped
    once per epoch, gradient-norm clipping at 1.0 and early stopping after two
    consecutive epochs without a validation-accuracy improvement. Each time
    validation accuracy improves, the student state dict is written to
    ``results/models/roberta_student_vanilla_best.pt``.

    Args:
        student_model: model being trained; called with ``input_ids`` and
            ``attention_mask`` and expected to expose ``.logits``.
        teacher_model: model providing soft targets; kept in eval mode and
            always run without gradients.
        train_loader: iterable of batches with keys ``input_ids``,
            ``attention_mask`` and ``qa_class``.
        val_loader: validation batches with the same keys.
        epochs: maximum number of epochs.
        learning_rate: AdamW learning rate.
        weight_decay: weight decay for the decayed parameter group.
        alpha: weight of the soft-target (KL) term; ``1 - alpha`` weights the
            hard-label cross-entropy.
        temperature: softmax temperature used for the soft targets.
        num_classes: number of classes for the per-class accuracy report.
        label_map: optional ``{class index: display name}``; defaults to the
            SST-2 Negative/Positive mapping.

    Returns:
        ``(student_model, best_val_acc)``. The student carries the best
        checkpoint written during this call; if none was written it is
        returned as last trained.
    """
    print(f"Training vanilla student model with distillation for up to {epochs} epochs...")

    checkpoint_path = "results/models/roberta_student_vanilla_best.pt"

    # Optimizer setup: exclude biases and LayerNorm weights from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in student_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Scheduler: multiply the learning rate by 0.8 every 2 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)

    # Loss functions
    ce_loss_fn = nn.CrossEntropyLoss()
    # KLDivLoss takes log-probabilities as input and probabilities as target
    kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

    # Training tracking
    best_val_acc = 0.0
    patience = 2
    no_improvement = 0
    # Tracks whether THIS call wrote a checkpoint; a file left behind by an
    # earlier run (possibly for a different architecture) must not be loaded
    checkpoint_saved = False

    if label_map is None:
        label_map = {0: "Negative", 1: "Positive"}

    for epoch in range(epochs):
        # Training
        student_model.train()
        teacher_model.eval()

        train_loss = 0
        train_ce_loss = 0
        train_kl_loss = 0
        train_correct = 0
        train_total = 0

        progress_bar = tqdm(train_loader, desc=f"Student Epoch {epoch + 1}/{epochs}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["qa_class"].to(device)

            optimizer.zero_grad()

            # Forward pass - student
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Forward pass - teacher (no gradients)
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            # Hard target loss (cross-entropy with true labels)
            ce_loss = ce_loss_fn(student_logits, labels)

            # Soft target loss (knowledge distillation); the T^2 factor offsets
            # the gradient down-scaling caused by dividing logits by T
            student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
            teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
            kl_loss = kl_loss_fn(student_log_probs, teacher_probs) * (temperature ** 2)

            # Combined loss
            loss = (1 - alpha) * ce_loss + alpha * kl_loss

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
            optimizer.step()

            # Track metrics (batch means weighted by batch size)
            train_loss += loss.item() * labels.size(0)
            train_ce_loss += ce_loss.item() * labels.size(0)
            train_kl_loss += kl_loss.item() * labels.size(0)

            preds = torch.argmax(student_logits, dim=-1)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

            progress_bar.set_postfix({
                "total_loss": loss.item(),
                "ce_loss": ce_loss.item(),
                "kl_loss": kl_loss.item(),
                "acc": train_correct / train_total,
                "lr": optimizer.param_groups[0]['lr']
            })

        scheduler.step()
        epoch_train_loss = train_loss / train_total
        epoch_train_ce_loss = train_ce_loss / train_total
        epoch_train_kl_loss = train_kl_loss / train_total
        epoch_train_acc = train_correct / train_total

        # Validation
        student_model.eval()
        val_loss = 0
        val_ce_loss = 0
        val_kl_loss = 0
        val_correct = 0
        val_total = 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["qa_class"].to(device)

                # Student predictions
                student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
                student_logits = student_outputs.logits

                # Teacher predictions
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

                # Losses
                ce_loss = ce_loss_fn(student_logits, labels)

                student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
                teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
                kl_loss = kl_loss_fn(student_log_probs, teacher_probs) * (temperature ** 2)

                loss = (1 - alpha) * ce_loss + alpha * kl_loss

                val_loss += loss.item() * labels.size(0)
                val_ce_loss += ce_loss.item() * labels.size(0)
                val_kl_loss += kl_loss.item() * labels.size(0)

                # Accuracy
                preds = torch.argmax(student_logits, dim=-1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                # Per-class accuracy
                for i in range(num_classes):
                    class_mask = (labels == i)
                    class_count = class_mask.sum().item()
                    if class_count > 0:
                        class_correct[i] += ((preds == labels) & class_mask).sum().item()
                        class_total[i] += class_count

        epoch_val_loss = val_loss / val_total
        epoch_val_ce_loss = val_ce_loss / val_total
        epoch_val_kl_loss = val_kl_loss / val_total
        epoch_val_acc = val_correct / val_total

        print(f"Student Epoch {epoch + 1}/{epochs}:")
        print(f"  Train - Total: {epoch_train_loss:.4f}, CE: {epoch_train_ce_loss:.4f}, KL: {epoch_train_kl_loss:.4f}, Acc: {epoch_train_acc:.4f}")
        print(f"  Val   - Total: {epoch_val_loss:.4f}, CE: {epoch_val_ce_loss:.4f}, KL: {epoch_val_kl_loss:.4f}, Acc: {epoch_val_acc:.4f}")

        # Print per-class accuracy
        print("  Per-class validation accuracy:")
        for i in range(num_classes):
            if class_total[i] > 0:
                class_acc = class_correct[i] / class_total[i]
                class_name = label_map.get(i, f"Class {i}")
                print(f"    {class_name}: {class_acc:.4f} ({class_correct[i]}/{class_total[i]})")

        # Save best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            no_improvement = 0
            torch.save(student_model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best student model saved! Val accuracy: {best_val_acc:.4f}")
        else:
            no_improvement += 1
            print(f"  No improvement for {no_improvement} epochs.")

            if no_improvement >= patience:
                print(f"  Early stopping at epoch {epoch + 1}.")
                break

    # Load best model - only a checkpoint produced by this run
    if checkpoint_saved:
        # map_location keeps the tensors on the device the model is being run on
        student_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best student model with val accuracy: {best_val_acc:.4f}")
    else:
        print("No student checkpoint was saved during this run; returning the model as last trained.")

    return student_model, best_val_acc


def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements


def main_vanilla_distillation():
    """Run the whole pipeline: finetune the teacher, distil the student, compare.

    Steps: load SST-2, build and finetune a truncated RoBERTa teacher, build a
    smaller student, train it with knowledge distillation, then evaluate both
    models on the validation loader and print a size/accuracy comparison.

    Returns:
        ``(teacher_model, student_model, teacher_final_acc, student_final_acc)``.
    """
    print("\n=== Vanilla 2D Student-Teacher RoBERTa with Knowledge Distillation ===\n")

    # Run settings
    teacher_layers = 3
    batch_size = 8
    max_length = 64
    teacher_epochs = 3
    student_epochs = 3
    distill_alpha = 0.5  # equal weight on hard and soft targets
    temperature = 4.0
    n_classes = 2

    # Student size relative to the teacher
    hidden_ratio = 0.5  # half the hidden size
    layer_ratio = 0.67  # roughly two thirds of the layers

    print(f"Configuration:")
    print(f"  Batch Size: {batch_size}")
    print(f"  Max Length: {max_length}")
    print(f"  Teacher Layers: {teacher_layers}")
    print(f"  Student Hidden Ratio: {hidden_ratio}")
    print(f"  Student Layer Ratio: {layer_ratio}")
    print(f"  Distillation Alpha: {distill_alpha}")
    print(f"  Temperature: {temperature}")

    # Data (the tokenizer itself is not needed any further here)
    _, train_loader, val_loader, label_map = load_sst2_data(batch_size=batch_size, max_length=max_length)

    # Teacher: build, move to the target device, report its size
    teacher_model, teacher_config = load_pretrained_teacher_model(num_layers=teacher_layers, num_classes=n_classes)
    teacher_model.to(device)

    teacher_params = count_parameters(teacher_model)
    print(f"Teacher model parameters: {teacher_params:,}")

    print("\n=== Step 1: Finetuning RoBERTa Teacher Model ===\n")
    teacher_model, _ = finetune_teacher_model(
        teacher_model,
        train_loader,
        val_loader,
        epochs=teacher_epochs,
        num_classes=n_classes,
        label_map=label_map,
    )

    # Student: build from the teacher config, move to the device, report its size
    print("\n=== Step 2: Creating Vanilla Student Model ===\n")
    student_model, _ = create_vanilla_student_model(
        teacher_config,
        hidden_size_ratio=hidden_ratio,
        num_layers_ratio=layer_ratio,
    )
    student_model.to(device)

    student_params = count_parameters(student_model)
    print(f"Student model parameters: {student_params:,}")
    param_ratio = student_params / teacher_params
    reduction_pct = (1 - param_ratio) * 100
    print(f"Parameter reduction: {reduction_pct:.1f}%")

    # Distillation
    print("\n=== Step 3: Training Vanilla Student with Knowledge Distillation ===\n")
    student_model, best_val_acc = train_vanilla_student_with_distillation(
        student_model,
        teacher_model,
        train_loader,
        val_loader,
        epochs=student_epochs,
        learning_rate=1e-4,
        weight_decay=0.01,
        alpha=distill_alpha,
        temperature=temperature,
        num_classes=n_classes,
    )

    print("\n=== Training Complete ===")
    print(f"Best student model validation accuracy: {best_val_acc:.4f}")
    print(f"Models saved to: results/models/")

    # Head-to-head evaluation on the validation loader
    print("\n=== Comparing Teacher vs Student Performance ===")
    teacher_model.eval()
    student_model.eval()

    teacher_hits = 0
    student_hits = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Final Evaluation"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["qa_class"].to(device)

            # Teacher first, then student, on the same batch
            teacher_logits = teacher_model(input_ids=ids, attention_mask=mask).logits
            teacher_hits += (torch.argmax(teacher_logits, dim=-1) == gold).sum().item()

            student_logits = student_model(input_ids=ids, attention_mask=mask).logits
            student_hits += (torch.argmax(student_logits, dim=-1) == gold).sum().item()

            n_seen += gold.size(0)

    teacher_final_acc = teacher_hits / n_seen
    student_final_acc = student_hits / n_seen

    print(f"\nFinal Results:")
    print(f"Teacher Model:")
    print(f"  Parameters: {teacher_params:,}")
    print(f"  Accuracy: {teacher_final_acc:.4f}")
    print(f"Student Model:")
    print(f"  Parameters: {student_params:,}")
    print(f"  Accuracy: {student_final_acc:.4f}")
    print(f"  Parameter Reduction: {reduction_pct:.1f}%")
    print(f"  Accuracy Drop: {(teacher_final_acc - student_final_acc)*100:.1f}%")

    # Accuracy divided by the student/teacher parameter ratio
    efficiency = student_final_acc / param_ratio
    print(f"  Efficiency (Acc/Param_Ratio): {efficiency:.2f}")

    return teacher_model, student_model, teacher_final_acc, student_final_acc


# Run the full distillation pipeline only when this file is executed as a script
if __name__ == "__main__":
    main_vanilla_distillation()