from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from nlp_training.exp_data import get_exp_data

def evaluate_model(
    model_name_or_path,
    batch_size=32,
    device=None,
    *,
    ds_name="ag_news",
    seed=42,
    num_samples=1000,
    max_length=256,
):
    """
    Evaluate a pretrained Hugging Face sequence classification model.

    Loads the model and tokenizer from ``model_name_or_path``, fetches an
    evaluation sample through ``get_exp_data``, runs batched inference without
    gradients, prints the metrics and a per-class report, and returns the
    aggregate metrics.

    Args:
        model_name_or_path (str): Path or model ID to pass to from_pretrained()
        batch_size (int): Batch size for evaluation DataLoader
        device (str, optional): e.g. "cuda" or "cpu". Defaults to GPU if available.
        ds_name (str): Dataset name forwarded to get_exp_data(). Keyword-only.
        seed (int): Sampling seed forwarded to get_exp_data(). Keyword-only.
        num_samples (int): Number of samples requested from get_exp_data().
            Keyword-only.
        max_length (int): Maximum token length; longer inputs are truncated.
            Keyword-only.
    Returns:
        dict: {'accuracy', 'precision', 'recall', 'f1'}; precision, recall and
        f1 are weighted averages over the classes.
    Raises:
        ValueError: If get_exp_data() returns no samples.
    """
    # 1. Setup device
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # 2. Load model & tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    # Decoder-only checkpoints such as GPT-2 ship without a padding token.
    # Without one the tokenizer refuses to pad and the classification head
    # cannot locate the last real token in batches larger than one, so fall
    # back to the EOS token. Checkpoints that already define a pad token are
    # left untouched.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # 3. Prepare data
    # get_exp_data is unpacked as a (texts, labels, num_classes) triple.
    texts, labels, num_classes = get_exp_data(
        ds_name=ds_name, seed=seed, num_samples=num_samples
    )
    # Non-string entries are joined with spaces (presumably pre-tokenized
    # word lists -- confirm against get_exp_data).
    texts = [text if isinstance(text, str) else " ".join(text) for text in texts]
    if not texts:
        raise ValueError(
            f"get_exp_data returned no samples for ds_name={ds_name!r}; nothing to evaluate"
        )
    print(f"Loaded {len(texts)} samples with {num_classes} classes")
    print(f"text is of type {type(texts)} with {len(texts)} samples")
    print(f"texts[0]: {texts[0]}")

    # The whole sample is tokenized in one call, so every batch shares the
    # same padded sequence length.
    encodings = tokenizer(
        texts,
        truncation=True,
        padding=True,
        return_tensors="pt",
        max_length=max_length,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    labels_tensor = torch.tensor(labels)

    ds = TensorDataset(input_ids, attention_mask, labels_tensor)
    loader = DataLoader(ds, batch_size=batch_size)

    # 4. Inference
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for ids, masks, labs in loader:
            # Labels stay on the CPU; only the model inputs are moved.
            ids, masks = ids.to(device), masks.to(device)
            outputs = model(ids, attention_mask=masks)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
            all_preds.extend(preds)
            all_labels.extend(labs.tolist())

    # 5. Metrics
    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="weighted"
    )

    # 6. Report
    print("Evaluation Results:")
    print(f"  Accuracy : {acc:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall   : {recall:.4f}")
    print(f"  F1 Score : {f1:.4f}\n")
    print("Classification Report:")
    print(classification_report(all_labels, all_preds))

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

# Script entry point: evaluate one fine-tuned checkpoint and print its metrics.
if __name__ == "__main__":
    MODEL_DIR = (
        "../data_files/TextFooler/saved_model/ft/"
        "model_gpt2_ds_ag_news_train_epoch_1_run_20250717_224152_seed_42/final_model"
    )
    print(evaluate_model(MODEL_DIR))
