import os

import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Configuration: checkpoint name, tokenization length and optimization settings.
MODEL_NAME = "bert-base-uncased"  # used for both the tokenizer and the model
MAX_LENGTH = 128  # every example is truncated/padded to this many tokens
BATCH_SIZE = 32  # per-device batch size for training and evaluation
EPOCHS = 100  # upper bound only; the early-stopping callback is meant to end training sooner
LEARNING_RATE = 2e-5  # passed straight to TrainingArguments
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the tokenizer matching MODEL_NAME and the AG News corpus (all splits).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset = load_dataset("ag_news")


# Preprocess function
def tokenize_function(examples):
    """Tokenize one batch of raw examples for BERT.

    Used with ``Dataset.map(..., batched=True)``, so ``examples["text"]`` holds
    the texts of a whole batch. Each sequence is truncated and padded to exactly
    ``MAX_LENGTH`` tokens so every row has the same length.

    Returns:
        The tokenizer's output mapping (``input_ids``, ``attention_mask``, ...).
    """
    batch_texts = examples["text"]
    encoding = tokenizer(
        batch_texts,
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
    )
    return encoding


# Tokenize every split, then expose only the model inputs and the label
# column, formatted as torch tensors.
dataset = dataset.map(tokenize_function, batched=True)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Hold out 10% of the original training split (fixed seed) for validation;
# the official test split is reserved for the final evaluation.
train_val = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_dataset, val_dataset = train_val["train"], train_val["test"]
test_dataset = dataset["test"]

# BERT classifier with num_labels=4 (presumably one per AG News topic class).
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4)
model = model.to(device)


# Early stopping callback
class EarlyStoppingCallback(TrainerCallback):
    """Stop training when a validation metric stops improving; keep the best weights.

    After every evaluation the monitored metric (higher is better) is compared
    with the best value seen so far. Every time a new best is observed --
    including the very first evaluation -- the model's ``state_dict`` is written
    to ``save_path``, so the file always exists once at least one evaluation
    produced the metric. After ``early_stopping_patience`` consecutive
    evaluations without improvement, the Trainer is asked to stop.

    Args:
        early_stopping_patience: Consecutive non-improving evaluations tolerated
            before stopping.
        metric_name: Key of the monitored metric in the ``metrics`` dict.
        save_path: File the best ``state_dict`` is written to.
    """

    def __init__(
        self,
        early_stopping_patience=2,
        metric_name="eval_accuracy",
        save_path="best_model.pt",
    ):
        self.early_stopping_patience = early_stopping_patience
        self.metric_name = metric_name
        self.save_path = save_path
        self.best_metric = None  # best value seen so far; None before any evaluation
        self.patience_counter = 0  # consecutive evaluations without improvement

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """Update the best metric, save improved weights and request a stop if stalled."""
        current_metric = metrics.get(self.metric_name)
        if current_metric is None:
            # Metric not reported (e.g. no compute_metrics); nothing to compare.
            return

        if self.best_metric is None or current_metric > self.best_metric:
            # New best (the first evaluation always counts): persist it now so a
            # run whose best epoch is the first one still leaves a checkpoint.
            self.best_metric = current_metric
            self.patience_counter = 0
            self._save_best(kwargs.get("model"))
            return

        self.patience_counter += 1
        if self.patience_counter >= self.early_stopping_patience:
            control.should_training_stop = True

    def _save_best(self, model):
        """Write ``model.state_dict()`` to ``self.save_path`` (no-op if no model was passed)."""
        if model is None:
            return
        torch.save(model.state_dict(), self.save_path)


# Compute metrics function
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair handed over by the Trainer; the
            predicted class is the argmax over the last axis of ``logits``.

    Returns:
        ``{"accuracy": <fraction of predictions equal to labels>}``.
    """
    logits, labels = eval_pred
    return {
        "accuracy": accuracy_score(y_true=labels, y_pred=np.argmax(logits, axis=-1)),
    }


# Training arguments: evaluate and checkpoint once per epoch. EPOCHS is only an
# upper bound -- the early-stopping callback is meant to end training sooner.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    # Best weights are restored manually from BEST_MODEL_PATH below instead of
    # through the Trainer's own checkpoint reloading.
    load_best_model_at_end=False,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
)

# File the early-stopping callback writes improved weights to.
BEST_MODEL_PATH = "best_model.pt"

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

# Delete weights left by an earlier run so the reload below can only pick up
# a checkpoint produced by this run.
if os.path.exists(BEST_MODEL_PATH):
    os.remove(BEST_MODEL_PATH)

# Train model
trainer.train()

# Restore the best validation checkpoint if one was written; otherwise keep the
# end-of-training weights instead of crashing on a missing file. `model` is the
# same object the Trainer holds, so loading into it updates trainer.model too.
if os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
else:
    print(
        f"Warning: {BEST_MODEL_PATH} not found; "
        "evaluating the model weights from the end of training."
    )

# Final evaluation on the held-out test split
test_results = trainer.predict(test_dataset)
print(f"Test Accuracy: {test_results.metrics['test_accuracy']:.4f}")
